Spaces:
Runtime error
Runtime error
pseudotensor
commited on
Commit
·
8d30b62
1
Parent(s):
24b4b28
Update with h2oGPT hash 880439992dce589c865d5ba3a4f183902f6fc8ec
Browse files- client_test.py +81 -52
- create_data.py +1818 -0
- finetune.py +4 -378
- generate.py +302 -110
- gpt4all_llm.py +119 -0
- gpt_langchain.py +1076 -0
- gradio_runner.py +634 -46
- gradio_themes.py +3 -1
- h2oai_pipeline.py +54 -0
- loaders.py +50 -0
- prompter.py +370 -1
- requirements.txt +54 -3
- utils.py +477 -6
client_test.py
CHANGED
@@ -36,84 +36,113 @@ Loaded as API: https://gpt.h2o.ai ✔
|
|
36 |
{'instruction_nochat': 'Who are you?', 'iinput_nochat': '', 'response': 'I am h2oGPT, a chatbot created by LAION.'}
|
37 |
|
38 |
"""
|
|
|
|
|
|
|
|
|
39 |
|
40 |
debug = False
|
41 |
|
42 |
-
import os
|
43 |
os.environ['HF_HUB_DISABLE_TELEMETRY'] = '1'
|
44 |
|
45 |
|
46 |
-
def get_client():
|
47 |
from gradio_client import Client
|
48 |
|
49 |
-
client = Client(os.getenv('HOST', "http://localhost:7860"))
|
50 |
if debug:
|
51 |
print(client.view_api(all_endpoints=True))
|
52 |
return client
|
53 |
|
54 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
55 |
def test_client_basic():
|
56 |
-
return
|
57 |
-
|
58 |
-
|
59 |
-
def
|
60 |
-
|
61 |
-
|
62 |
-
context = ''
|
63 |
-
# streaming output is supported, loops over and outputs each generation in streaming mode
|
64 |
-
# but leave stream_output=False for simple input/output mode
|
65 |
-
stream_output = False
|
66 |
-
temperature = 0.1
|
67 |
-
top_p = 0.75
|
68 |
-
top_k = 40
|
69 |
-
num_beams = 1
|
70 |
-
max_new_tokens = 50
|
71 |
-
min_new_tokens = 0
|
72 |
-
early_stopping = False
|
73 |
-
max_time = 20
|
74 |
-
repetition_penalty = 1.0
|
75 |
-
num_return_sequences = 1
|
76 |
-
do_sample = True
|
77 |
-
# only these 2 below used if pass chat=False
|
78 |
-
chat = False
|
79 |
-
iinput_nochat = ''
|
80 |
-
|
81 |
-
args = [instruction,
|
82 |
-
iinput,
|
83 |
-
context,
|
84 |
-
stream_output,
|
85 |
-
prompt_type,
|
86 |
-
temperature,
|
87 |
-
top_p,
|
88 |
-
top_k,
|
89 |
-
num_beams,
|
90 |
-
max_new_tokens,
|
91 |
-
min_new_tokens,
|
92 |
-
early_stopping,
|
93 |
-
max_time,
|
94 |
-
repetition_penalty,
|
95 |
-
num_return_sequences,
|
96 |
-
do_sample,
|
97 |
-
chat,
|
98 |
-
instruction_nochat,
|
99 |
-
iinput_nochat,
|
100 |
-
]
|
101 |
api_name = '/submit_nochat'
|
102 |
-
client = get_client()
|
103 |
res = client.predict(
|
104 |
*tuple(args),
|
105 |
api_name=api_name,
|
106 |
)
|
107 |
-
res_dict = dict(
|
|
|
108 |
print(res_dict)
|
109 |
return res_dict
|
110 |
|
111 |
|
112 |
-
|
113 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
114 |
|
115 |
|
116 |
def md_to_text(md):
|
|
|
117 |
html = markdown.markdown(md)
|
118 |
soup = BeautifulSoup(html, features='html.parser')
|
119 |
return soup.get_text()
|
|
|
36 |
{'instruction_nochat': 'Who are you?', 'iinput_nochat': '', 'response': 'I am h2oGPT, a chatbot created by LAION.'}
|
37 |
|
38 |
"""
|
39 |
+
import time
|
40 |
+
import os
|
41 |
+
import markdown # pip install markdown
|
42 |
+
from bs4 import BeautifulSoup # pip install beautifulsoup4
|
43 |
|
44 |
debug = False
|
45 |
|
|
|
46 |
os.environ['HF_HUB_DISABLE_TELEMETRY'] = '1'
|
47 |
|
48 |
|
49 |
+
def get_client(serialize=True):
|
50 |
from gradio_client import Client
|
51 |
|
52 |
+
client = Client(os.getenv('HOST', "http://localhost:7860"), serialize=serialize)
|
53 |
if debug:
|
54 |
print(client.view_api(all_endpoints=True))
|
55 |
return client
|
56 |
|
57 |
|
58 |
+
def get_args(prompt, prompt_type, chat=False, stream_output=False, max_new_tokens=50):
|
59 |
+
from collections import OrderedDict
|
60 |
+
kwargs = OrderedDict(instruction=prompt if chat else '', # only for chat=True
|
61 |
+
iinput='', # only for chat=True
|
62 |
+
context='',
|
63 |
+
# streaming output is supported, loops over and outputs each generation in streaming mode
|
64 |
+
# but leave stream_output=False for simple input/output mode
|
65 |
+
stream_output=stream_output,
|
66 |
+
prompt_type=prompt_type,
|
67 |
+
temperature=0.1,
|
68 |
+
top_p=0.75,
|
69 |
+
top_k=40,
|
70 |
+
num_beams=1,
|
71 |
+
max_new_tokens=max_new_tokens,
|
72 |
+
min_new_tokens=0,
|
73 |
+
early_stopping=False,
|
74 |
+
max_time=20,
|
75 |
+
repetition_penalty=1.0,
|
76 |
+
num_return_sequences=1,
|
77 |
+
do_sample=True,
|
78 |
+
chat=chat,
|
79 |
+
instruction_nochat=prompt if not chat else '',
|
80 |
+
iinput_nochat='', # only for chat=False
|
81 |
+
langchain_mode='Disabled',
|
82 |
+
)
|
83 |
+
if chat:
|
84 |
+
# add chatbot output on end. Assumes serialize=False
|
85 |
+
kwargs.update(dict(chatbot=[['', None]]))
|
86 |
+
|
87 |
+
return kwargs, list(kwargs.values())
|
88 |
+
|
89 |
+
|
90 |
def test_client_basic():
|
91 |
+
return run_client_nochat(prompt='Who are you?', prompt_type='human_bot', max_new_tokens=50)
|
92 |
+
|
93 |
+
|
94 |
+
def run_client_nochat(prompt, prompt_type, max_new_tokens):
|
95 |
+
kwargs, args = get_args(prompt, prompt_type, chat=False, max_new_tokens=max_new_tokens)
|
96 |
+
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
97 |
api_name = '/submit_nochat'
|
98 |
+
client = get_client(serialize=True)
|
99 |
res = client.predict(
|
100 |
*tuple(args),
|
101 |
api_name=api_name,
|
102 |
)
|
103 |
+
res_dict = dict(prompt=kwargs['instruction_nochat'], iinput=kwargs['iinput_nochat'],
|
104 |
+
response=md_to_text(res))
|
105 |
print(res_dict)
|
106 |
return res_dict
|
107 |
|
108 |
|
109 |
+
def test_client_chat():
|
110 |
+
return run_client_chat(prompt='Who are you?', prompt_type='human_bot', stream_output=False, max_new_tokens=50)
|
111 |
+
|
112 |
+
|
113 |
+
def run_client_chat(prompt, prompt_type, stream_output, max_new_tokens):
|
114 |
+
kwargs, args = get_args(prompt, prompt_type, chat=True, stream_output=stream_output, max_new_tokens=max_new_tokens)
|
115 |
+
|
116 |
+
client = get_client(serialize=False)
|
117 |
+
|
118 |
+
res = client.predict(*tuple(args), api_name='/instruction')
|
119 |
+
args[-1] += [res[-1]]
|
120 |
+
|
121 |
+
res_dict = kwargs
|
122 |
+
res_dict['prompt'] = prompt
|
123 |
+
if not kwargs['stream_output']:
|
124 |
+
res = client.predict(*tuple(args), api_name='/instruction_bot')
|
125 |
+
res_dict['response'] = res[0][-1][1]
|
126 |
+
print(md_to_text(res_dict['response']))
|
127 |
+
return res_dict
|
128 |
+
else:
|
129 |
+
job = client.submit(*tuple(args), api_name='/instruction_bot')
|
130 |
+
res1 = ''
|
131 |
+
while not job.done():
|
132 |
+
outputs_list = job.communicator.job.outputs
|
133 |
+
if outputs_list:
|
134 |
+
res = job.communicator.job.outputs[-1]
|
135 |
+
res1 = res[0][-1][-1]
|
136 |
+
res1 = md_to_text(res1)
|
137 |
+
print(res1)
|
138 |
+
time.sleep(0.1)
|
139 |
+
print(job.outputs())
|
140 |
+
res_dict['response'] = res1
|
141 |
+
return res_dict
|
142 |
|
143 |
|
144 |
def md_to_text(md):
|
145 |
+
assert md is not None, "Markdown is None"
|
146 |
html = markdown.markdown(md)
|
147 |
soup = BeautifulSoup(html, features='html.parser')
|
148 |
return soup.get_text()
|
create_data.py
ADDED
@@ -0,0 +1,1818 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Dataset creation tools.
|
3 |
+
|
4 |
+
Keep to-level imports clean of non-trivial imports for specific tools,
|
5 |
+
because this file is imported for various purposes
|
6 |
+
"""
|
7 |
+
|
8 |
+
import ast
|
9 |
+
import concurrent.futures
|
10 |
+
import contextlib
|
11 |
+
import hashlib
|
12 |
+
import json
|
13 |
+
import os
|
14 |
+
import shutil
|
15 |
+
import signal
|
16 |
+
import sys
|
17 |
+
import traceback
|
18 |
+
from concurrent.futures import ProcessPoolExecutor
|
19 |
+
|
20 |
+
import psutil
|
21 |
+
import pytest
|
22 |
+
import pandas as pd
|
23 |
+
import numpy as np
|
24 |
+
from tqdm import tqdm
|
25 |
+
|
26 |
+
from utils import flatten_list
|
27 |
+
|
28 |
+
|
29 |
+
def parse_rst_file(filepath):
|
30 |
+
with open(filepath, 'r') as f:
|
31 |
+
input_data = f.read()
|
32 |
+
settings_overrides = {'initial_header_level': 2}
|
33 |
+
from docutils import core
|
34 |
+
document = core.publish_doctree(
|
35 |
+
source=input_data,
|
36 |
+
source_path=filepath,
|
37 |
+
settings_overrides=settings_overrides,
|
38 |
+
)
|
39 |
+
qa_pairs = []
|
40 |
+
current_section = None
|
41 |
+
current_question = ""
|
42 |
+
current_answer = ""
|
43 |
+
for node in document.traverse():
|
44 |
+
if node.__class__.__name__ == 'section':
|
45 |
+
current_section = ""
|
46 |
+
elif current_section is not None:
|
47 |
+
if node.__class__.__name__ == 'Text':
|
48 |
+
if node.astext()[-1] == "?":
|
49 |
+
if current_question:
|
50 |
+
qa_pairs.append((current_question, current_answer))
|
51 |
+
current_question = node.astext()
|
52 |
+
current_answer = ""
|
53 |
+
else:
|
54 |
+
current_answer += node.astext()
|
55 |
+
if current_answer:
|
56 |
+
qa_pairs.append((current_question, current_answer))
|
57 |
+
return {k: v for k, v in qa_pairs}
|
58 |
+
|
59 |
+
|
60 |
+
def test_scrape_dai_docs():
|
61 |
+
home = os.path.expanduser('~')
|
62 |
+
file = os.path.join(home, 'h2oai/docs/faq.rst')
|
63 |
+
qa_pairs = parse_rst_file(file)
|
64 |
+
prompt_type = 'human_bot'
|
65 |
+
from prompter import prompt_types
|
66 |
+
assert prompt_type in prompt_types
|
67 |
+
save_thing = [{"instruction": k, "output": v, 'prompt_type': prompt_type} for k, v in qa_pairs.items()]
|
68 |
+
output_file = "dai_faq.json"
|
69 |
+
with open(output_file, "wt") as f:
|
70 |
+
f.write(json.dumps(save_thing, indent=2))
|
71 |
+
|
72 |
+
|
73 |
+
def test_scrape_dai_docs_all():
|
74 |
+
"""
|
75 |
+
pytest create_data.py::test_scrape_dai_docs_all
|
76 |
+
"""
|
77 |
+
import glob
|
78 |
+
import nltk
|
79 |
+
nltk.download('punkt')
|
80 |
+
dd = {}
|
81 |
+
np.random.seed(1234)
|
82 |
+
home = os.path.expanduser('~')
|
83 |
+
files = list(glob.glob(os.path.join(home, "h2oai/docs/**/*rst")))
|
84 |
+
np.random.shuffle(files)
|
85 |
+
val_count = int(0.05 * len(files))
|
86 |
+
train_files = files[val_count:]
|
87 |
+
valid_files = files[:val_count]
|
88 |
+
things = [
|
89 |
+
("dai_docs.train.json", train_files),
|
90 |
+
("dai_docs.valid.json", valid_files)
|
91 |
+
]
|
92 |
+
for LEN in [100, 200, 500]:
|
93 |
+
for output_file, ff in things:
|
94 |
+
if output_file not in dd:
|
95 |
+
dd[output_file] = []
|
96 |
+
for f in ff:
|
97 |
+
with open(f) as input:
|
98 |
+
blob = input.read()
|
99 |
+
blob = blob.replace("~~", "")
|
100 |
+
blob = blob.replace("==", "")
|
101 |
+
blob = blob.replace("''", "")
|
102 |
+
blob = blob.replace("--", "")
|
103 |
+
blob = blob.replace("**", "")
|
104 |
+
dd[output_file].extend(get_sentences(blob, length=LEN))
|
105 |
+
for output_file, _ in things:
|
106 |
+
save_thing = [{"output": k.strip(), 'prompt_type': 'plain'} for k in dd[output_file]]
|
107 |
+
with open(output_file, "wt") as f:
|
108 |
+
f.write(json.dumps(save_thing, indent=2))
|
109 |
+
|
110 |
+
|
111 |
+
def get_sentences(blob, length):
|
112 |
+
"""
|
113 |
+
break-up input text into sentences and then output list of sentences of about length in size
|
114 |
+
:param blob:
|
115 |
+
:param length:
|
116 |
+
:return:
|
117 |
+
"""
|
118 |
+
import nltk
|
119 |
+
nltk.download('punkt')
|
120 |
+
from nltk.tokenize import sent_tokenize
|
121 |
+
sentences = sent_tokenize(blob)
|
122 |
+
my_sentences = []
|
123 |
+
my_string = ""
|
124 |
+
for sentence in sentences:
|
125 |
+
if len(my_string) + len(sentence) <= length:
|
126 |
+
if my_string:
|
127 |
+
my_string += " " + sentence
|
128 |
+
else:
|
129 |
+
my_string = sentence
|
130 |
+
else:
|
131 |
+
my_sentences.append(my_string)
|
132 |
+
my_string = ""
|
133 |
+
return my_sentences or [my_string]
|
134 |
+
|
135 |
+
|
136 |
+
def setup_dai_docs(path=None, dst="working_dir_docs", from_hf=False):
|
137 |
+
"""
|
138 |
+
Only supported if have access to source code or HF token for HF spaces and from_hf=True
|
139 |
+
:param path:
|
140 |
+
:param dst:
|
141 |
+
:param from_hf:
|
142 |
+
:return:
|
143 |
+
"""
|
144 |
+
|
145 |
+
home = os.path.expanduser('~')
|
146 |
+
|
147 |
+
if from_hf:
|
148 |
+
# assumes
|
149 |
+
from huggingface_hub import hf_hub_download
|
150 |
+
# True for case when locally already logged in with correct token, so don't have to set key
|
151 |
+
token = os.getenv('HUGGINGFACE_API_TOKEN', True)
|
152 |
+
path_to_zip_file = hf_hub_download('h2oai/dai_docs', 'dai_docs.zip', token=token, repo_type='dataset')
|
153 |
+
path = 'h2oai'
|
154 |
+
import zipfile
|
155 |
+
with zipfile.ZipFile(path_to_zip_file, 'r') as zip_ref:
|
156 |
+
zip_ref.extractall(path)
|
157 |
+
path = os.path.join(path, 'docs/**/*')
|
158 |
+
|
159 |
+
if path is None:
|
160 |
+
if os.path.isdir(os.path.join(home, 'h2oai')):
|
161 |
+
path = os.path.join(home, "h2oai/docs/**/*")
|
162 |
+
else:
|
163 |
+
assert os.path.isdir(os.path.join(home, 'h2oai.superclean')), '%s does not exist' % path
|
164 |
+
path = os.path.join(home, "h2oai.superclean/docs/**/*")
|
165 |
+
import glob
|
166 |
+
files = list(glob.glob(path, recursive=True))
|
167 |
+
|
168 |
+
# pandoc can't find include files
|
169 |
+
|
170 |
+
remove(dst)
|
171 |
+
os.makedirs(dst)
|
172 |
+
|
173 |
+
# copy full tree, for absolute paths in rst
|
174 |
+
for fil in files:
|
175 |
+
if os.path.isfile(fil):
|
176 |
+
shutil.copy(fil, dst)
|
177 |
+
|
178 |
+
# hack for relative path
|
179 |
+
scorers_dir = os.path.join(dst, 'scorers')
|
180 |
+
makedirs(scorers_dir)
|
181 |
+
for fil in glob.glob(os.path.join(dst, '*.frag')):
|
182 |
+
shutil.copy(fil, scorers_dir)
|
183 |
+
|
184 |
+
return dst
|
185 |
+
|
186 |
+
|
187 |
+
def rst_to_outputs(files, min_len=30, max_len=2048//2 - 30):
|
188 |
+
# account for sequence length (context window) including prompt and input and output
|
189 |
+
|
190 |
+
# os.system('pandoc -f rst -t plain ./expert_settings/nlp_settings.rst')
|
191 |
+
import pypandoc
|
192 |
+
basedir = os.path.abspath(os.getcwd())
|
193 |
+
|
194 |
+
outputs = []
|
195 |
+
for fil in files:
|
196 |
+
os.chdir(basedir)
|
197 |
+
os.chdir(os.path.dirname(fil))
|
198 |
+
fil = os.path.basename(fil)
|
199 |
+
print("Processing %s" % fil, flush=True)
|
200 |
+
# out_format can be one of: asciidoc, asciidoctor, beamer, biblatex, bibtex, commonmark, commonmark_x,
|
201 |
+
# context, csljson, docbook, docbook4, docbook5, docx, dokuwiki,
|
202 |
+
# dzslides, epub, epub2, epub3, fb2, gfm, haddock, html, html4, html5, icml,
|
203 |
+
# ipynb, jats, jats_archiving, jats_articleauthoring, jats_publishing, jira,
|
204 |
+
# json, latex, man,
|
205 |
+
# markdown, markdown_github, markdown_mmd, markdown_phpextra, markdown_strict,
|
206 |
+
# mediawiki, ms, muse, native, odt, opendocument, opml, org, pdf, plain, pptx,
|
207 |
+
# revealjs, rst, rtf, s5, slideous, slidy, tei, texinfo, textile, xwiki, zimwiki
|
208 |
+
out_format = 'plain'
|
209 |
+
# avoid extra new lines injected into text
|
210 |
+
extra_args = ['--wrap=preserve', '--resource path="%s" % dst']
|
211 |
+
|
212 |
+
plain_list = []
|
213 |
+
try:
|
214 |
+
# valid for expert settings
|
215 |
+
input_rst = pypandoc.convert_file(fil, 'rst')
|
216 |
+
input_list = input_rst.split('\n``')
|
217 |
+
for input_subrst in input_list:
|
218 |
+
input_plain = pypandoc.convert_text(input_subrst, format='rst', to='plain')
|
219 |
+
plain_list.append([input_plain, fil])
|
220 |
+
except Exception as e:
|
221 |
+
print("file exception: %s %s" % (fil, str(e)), flush=True)
|
222 |
+
|
223 |
+
if not plain_list:
|
224 |
+
# if failed to process as pieces of rst, then
|
225 |
+
output = pypandoc.convert_file(fil, out_format, extra_args=extra_args, format='rst')
|
226 |
+
outputs1 = get_sentences(output, length=max_len)
|
227 |
+
for oi, output in enumerate(outputs1):
|
228 |
+
output = output.replace('\n\n', '\n')
|
229 |
+
plain_list.append([output, fil])
|
230 |
+
outputs.extend(plain_list)
|
231 |
+
|
232 |
+
# report:
|
233 |
+
# [print(len(x)) for x in outputs]
|
234 |
+
|
235 |
+
# deal with blocks longer than context size (sequence length) of 2048
|
236 |
+
new_outputs = []
|
237 |
+
num_truncated = 0
|
238 |
+
num_orig = len(outputs)
|
239 |
+
for output, fil in outputs:
|
240 |
+
if len(output) < max_len:
|
241 |
+
new_outputs.append([output, fil])
|
242 |
+
continue
|
243 |
+
outputs1 = get_sentences(output, length=max_len)
|
244 |
+
for oi, output1 in enumerate(outputs1):
|
245 |
+
output1 = output1.replace('\n\n', '\n')
|
246 |
+
new_outputs.append([output1, fil])
|
247 |
+
num_truncated += 1
|
248 |
+
print('num_orig: %s num_truncated: %s' % (num_orig, num_truncated), flush=True)
|
249 |
+
|
250 |
+
new_outputs = [[k.strip(), fil] for k, fil in new_outputs if len(k.strip()) > min_len]
|
251 |
+
|
252 |
+
return new_outputs
|
253 |
+
|
254 |
+
|
255 |
+
def test_scrape_dai_docs_all_pandoc():
|
256 |
+
"""
|
257 |
+
pytest -s -v create_data.py::test_scrape_dai_docs_all_pandoc
|
258 |
+
:return:
|
259 |
+
"""
|
260 |
+
|
261 |
+
dst = setup_dai_docs()
|
262 |
+
|
263 |
+
import glob
|
264 |
+
files = list(glob.glob(os.path.join(dst, '*rst'), recursive=True))
|
265 |
+
|
266 |
+
basedir = os.path.abspath(os.getcwd())
|
267 |
+
new_outputs = rst_to_outputs(files)
|
268 |
+
os.chdir(basedir)
|
269 |
+
|
270 |
+
remove(dst)
|
271 |
+
save_thing = [{"output": k.strip(), 'prompt_type': 'plain'} for k in new_outputs]
|
272 |
+
output_file = "dai_docs.train_cleaned.json"
|
273 |
+
with open(output_file, "wt") as f:
|
274 |
+
f.write(json.dumps(save_thing, indent=2))
|
275 |
+
|
276 |
+
|
277 |
+
def remove(path: str):
|
278 |
+
try:
|
279 |
+
if path is not None and os.path.exists(path):
|
280 |
+
if os.path.isdir(path):
|
281 |
+
shutil_rmtree(path, ignore_errors=True)
|
282 |
+
else:
|
283 |
+
with contextlib.suppress(FileNotFoundError):
|
284 |
+
os.remove(path)
|
285 |
+
except:
|
286 |
+
pass
|
287 |
+
|
288 |
+
|
289 |
+
def shutil_rmtree(*args, **kwargs):
|
290 |
+
return shutil.rmtree(*args, **kwargs)
|
291 |
+
|
292 |
+
|
293 |
+
def test_config_to_json():
|
294 |
+
"""
|
295 |
+
Needs to run from Driverless AI source directory.
|
296 |
+
E.g. (base) jon@gpu:~/h2oai$ pytest -s -v /data/jon/h2ogpt/create_data.py::test_config_to_json ; cp config.json /data/jon/h2ogpt/
|
297 |
+
:return:
|
298 |
+
"""
|
299 |
+
try:
|
300 |
+
# Arrange
|
301 |
+
import json
|
302 |
+
from h2oaicore.systemutils import config
|
303 |
+
toml_list = []
|
304 |
+
for k, v in config.get_meta_dict().items():
|
305 |
+
title = (v.title + ": ") if v.title else ''
|
306 |
+
comment = v.comment or ''
|
307 |
+
if not (title or comment):
|
308 |
+
continue
|
309 |
+
toml_list.extend(
|
310 |
+
[
|
311 |
+
{
|
312 |
+
'prompt_type': 'plain',
|
313 |
+
'instruction': f"<human>: What does {k} do?\n<bot>: {k.replace('_', ' ')} config.toml: {comment or title}\n<human>:".replace("\n", ""),
|
314 |
+
},
|
315 |
+
{
|
316 |
+
'prompt_type': 'plain',
|
317 |
+
'instruction': f"<human>: Explain {k}.\n<bot>: {k.replace('_', ' ')} config.toml: {comment or title}\n<human>:".replace("\n", ""),
|
318 |
+
},
|
319 |
+
{
|
320 |
+
'prompt_type': 'plain',
|
321 |
+
'instruction': f"<human>: How can I do this: {title}.\n<bot>: Set the {k.replace('_', ' ')} config.toml\n<human>:".replace("\n", ""),
|
322 |
+
} if title and comment else None,
|
323 |
+
{
|
324 |
+
'prompt_type': 'human_bot',
|
325 |
+
'instruction': f'Explain the following expert setting for Driverless AI',
|
326 |
+
'input': f"{k}",
|
327 |
+
'output': f"{k.replace('_', ' ')} config.toml: {comment or title}".replace("\n", ""),
|
328 |
+
},
|
329 |
+
{
|
330 |
+
'prompt_type': 'human_bot',
|
331 |
+
'instruction': f'Explain the following expert setting for Driverless AI',
|
332 |
+
'input': f"{k}",
|
333 |
+
'output': f"{k.replace('_', ' ')} config.toml: {title}{comment}".replace("\n", ""),
|
334 |
+
},
|
335 |
+
{
|
336 |
+
'prompt_type': 'human_bot',
|
337 |
+
'instruction': f'Explain the following expert setting for Driverless AI',
|
338 |
+
'input': f"{k.replace('_', ' ')}",
|
339 |
+
'output': f"{k.replace('_', ' ')} config.toml: {title}{comment}".replace("\n", ""),
|
340 |
+
},
|
341 |
+
{
|
342 |
+
'prompt_type': 'human_bot',
|
343 |
+
'instruction': f'Explain the following expert setting for Driverless AI',
|
344 |
+
'input': f"{title}",
|
345 |
+
'output': f"{k.replace('_', ' ')} config.toml: {title}{comment}".replace("\n", ""),
|
346 |
+
},
|
347 |
+
{
|
348 |
+
'prompt_type': 'human_bot',
|
349 |
+
'instruction': f'Provide a short explanation of the expert setting {k}',
|
350 |
+
'output': f"{k.replace('_', ' ')} config.toml: {comment or title}".replace("\n", ""),
|
351 |
+
},
|
352 |
+
{
|
353 |
+
'prompt_type': 'human_bot',
|
354 |
+
'instruction': f'Provide a detailed explanation of the expert setting {k}',
|
355 |
+
'output': f"{k.replace('_', ' ')} config.toml: {title}{comment}".replace("\n", ""),
|
356 |
+
},
|
357 |
+
]
|
358 |
+
)
|
359 |
+
toml_list = [x for x in toml_list if x]
|
360 |
+
with open("config.json", "wt") as f:
|
361 |
+
f.write(json.dumps(toml_list, indent=2))
|
362 |
+
except Exception as e:
|
363 |
+
print("Exception: %s" % str(e), flush=True)
|
364 |
+
|
365 |
+
|
366 |
+
def copy_tree(src, dst, follow_symlink=False):
|
367 |
+
makedirs(dst, exist_ok=True)
|
368 |
+
for (path, dirs, files) in os.walk(src, followlinks=follow_symlink):
|
369 |
+
new_path = path.replace(src, dst)
|
370 |
+
makedirs(new_path, exist_ok=True)
|
371 |
+
for file in files:
|
372 |
+
filename = os.path.join(path, file)
|
373 |
+
new_filename = os.path.join(new_path, file)
|
374 |
+
# print("%s -> %s" % (filename, new_filename))
|
375 |
+
try:
|
376 |
+
atomic_copy(filename, new_filename)
|
377 |
+
except FileNotFoundError:
|
378 |
+
pass
|
379 |
+
|
380 |
+
|
381 |
+
def atomic_move(src, dst):
|
382 |
+
try:
|
383 |
+
shutil.move(src, dst)
|
384 |
+
except (shutil.Error, FileExistsError):
|
385 |
+
pass
|
386 |
+
remove(src)
|
387 |
+
|
388 |
+
|
389 |
+
def atomic_copy(src=None, dst=None, with_permissions=True):
|
390 |
+
if os.path.isfile(dst):
|
391 |
+
return
|
392 |
+
import uuid
|
393 |
+
my_uuid = uuid.uuid4()
|
394 |
+
dst_tmp = dst + str(my_uuid)
|
395 |
+
makedirs(os.path.dirname(dst), exist_ok=True)
|
396 |
+
if with_permissions:
|
397 |
+
shutil.copy(src, dst_tmp)
|
398 |
+
else:
|
399 |
+
shutil.copyfile(src, dst_tmp)
|
400 |
+
atomic_move(dst_tmp, dst)
|
401 |
+
remove(dst_tmp)
|
402 |
+
|
403 |
+
|
404 |
+
def makedirs(path, exist_ok=True):
|
405 |
+
"""
|
406 |
+
Avoid some inefficiency in os.makedirs()
|
407 |
+
:param path:
|
408 |
+
:param exist_ok:
|
409 |
+
:return:
|
410 |
+
"""
|
411 |
+
if os.path.isdir(path) and os.path.exists(path):
|
412 |
+
assert exist_ok, "Path already exists"
|
413 |
+
return path
|
414 |
+
os.makedirs(path, exist_ok=exist_ok)
|
415 |
+
|
416 |
+
|
417 |
+
## Download from https://huggingface.co/datasets/anon8231489123/ShareGPT_Vicuna_unfiltered/resolve/main/ShareGPT_unfiltered_cleaned_split.json
|
418 |
+
## Turn into simple instruct prompt type. No context/previous conversations.
|
419 |
+
def test_prep_instruct_vicuna():
|
420 |
+
from datasets import load_dataset
|
421 |
+
filename = 'ShareGPT_unfiltered_cleaned_split.json'
|
422 |
+
if not os.path.exists(filename):
|
423 |
+
os.system('wget https://huggingface.co/datasets/anon8231489123/ShareGPT_Vicuna_unfiltered/resolve/main/%s' % filename)
|
424 |
+
data = load_dataset("json", data_files={"train": filename})["train"]
|
425 |
+
training_rows = []
|
426 |
+
for i in range(data.num_rows):
|
427 |
+
conversations = data[i]['conversations']
|
428 |
+
assert isinstance(conversations, list), conversations
|
429 |
+
convo = ""
|
430 |
+
for j, conv in enumerate(conversations):
|
431 |
+
# Get ready for generate.py prompt_type=human_bot
|
432 |
+
# But train with prompt_type=plain
|
433 |
+
if conv['from'] == 'human':
|
434 |
+
FROM = '<human>: '
|
435 |
+
elif conv['from'] == 'gpt':
|
436 |
+
FROM = '<bot>: '
|
437 |
+
convo += f"{FROM}" + conv['value'] + "\n"
|
438 |
+
if convo:
|
439 |
+
training_rows.append(dict(input=convo))
|
440 |
+
with open(filename + ".generate_human_bot.train_plain.json", "wt") as f:
|
441 |
+
f.write(json.dumps(training_rows, indent=2))
|
442 |
+
|
443 |
+
POSTFIX = ".generate_human_bot.train_plain.json"
|
444 |
+
|
445 |
+
# https://bair.berkeley.edu/blog/2023/04/03/koala/
|
446 |
+
OIG_DATASETS = [
|
447 |
+
"unified_chip2.jsonl",
|
448 |
+
"unified_grade_school_math_instructions.jsonl",
|
449 |
+
"unified_poetry_2_song.jsonl",
|
450 |
+
"unified_plot_screenplay_books_dialog.jsonl",
|
451 |
+
]
|
452 |
+
|
453 |
+
# hub issue: https://huggingface.co/datasets/laion/OIG/discussions/4
|
454 |
+
ALL_OIG_DATASETS = ['unified_abstract_infill.jsonl',
|
455 |
+
'unified_basic.jsonl',
|
456 |
+
'unified_canadian_parliament.jsonl',
|
457 |
+
'unified_chip2.jsonl',
|
458 |
+
'unified_conv_finqa.jsonl',
|
459 |
+
'unified_cuad.jsonl',
|
460 |
+
'unified_essays.jsonl',
|
461 |
+
'unified_flan.jsonl.gz',
|
462 |
+
'unified_grade_school_math_instructions.jsonl',
|
463 |
+
'unified_hc3_human.jsonl',
|
464 |
+
'unified_image_prompts_instructions.jsonl',
|
465 |
+
'unified_joke_explanations.jsonl',
|
466 |
+
'unified_mathqa_flanv2_kojma_cot.jsonl',
|
467 |
+
'unified_merged_code_xp3.jsonl',
|
468 |
+
'unified_multi_news.jsonl',
|
469 |
+
'unified_multi_sum.jsonl',
|
470 |
+
'unified_ni.jsonl.gz',
|
471 |
+
'unified_nq.jsonl',
|
472 |
+
'unified_openai_summarize_tldr.jsonl',
|
473 |
+
'unified_oscar_en_sample_dialog.jsonl',
|
474 |
+
'unified_p3.jsonl.gz',
|
475 |
+
'unified_plot_screenplay_books_dialog.jsonl',
|
476 |
+
'unified_poetry_2_song.jsonl',
|
477 |
+
'unified_poetry_instructions.jsonl',
|
478 |
+
'unified_rallio_safety_and_prosocial.jsonl',
|
479 |
+
'unified_rallio_soda_upgraded_2048.jsonl',
|
480 |
+
'unified_soda_dialog.jsonl',
|
481 |
+
'unified_sqlv1.jsonl',
|
482 |
+
'unified_sqlv2.jsonl',
|
483 |
+
'unified_squad_v2.jsonl',
|
484 |
+
'unified_squad_v2_more_neg.jsonl',
|
485 |
+
'unified_ul2_plus_oscar_en_sample_dialog.jsonl',
|
486 |
+
'unified_unifiedskg_instructions.jsonl',
|
487 |
+
'unified_unnatural_instructions.jsonl',
|
488 |
+
'unified_xp3_sample.jsonl']
|
489 |
+
|
490 |
+
useful_oig_files = ['unified_rallio_safety_and_prosocial.jsonl.parquet',
|
491 |
+
'unified_chip2.jsonl.parquet',
|
492 |
+
'unified_cuad.jsonl.parquet',
|
493 |
+
'unified_essays.jsonl.parquet',
|
494 |
+
'unified_flan.jsonl.gz.parquet',
|
495 |
+
'unified_grade_school_math_instructions.jsonl.parquet',
|
496 |
+
'unified_hc3_human.jsonl.parquet',
|
497 |
+
'unified_mathqa_flanv2_kojma_cot.jsonl.parquet',
|
498 |
+
'unified_merged_code_xp3.jsonl.parquet',
|
499 |
+
'unified_multi_news.jsonl.parquet',
|
500 |
+
#'unified_multi_sum.jsonl.parquet'
|
501 |
+
'unified_ni.jsonl.gz.parquet',
|
502 |
+
'unified_openai_summarize_tldr.jsonl.parquet',
|
503 |
+
#'unified_oscar_en_sample_dialog.jsonl.parquet', # create text containing these N words, not specific
|
504 |
+
'unified_plot_screenplay_books_dialog.jsonl.parquet',
|
505 |
+
'unified_soda_dialog.jsonl.parquet',
|
506 |
+
'unified_unnatural_instructions.jsonl.parquet',
|
507 |
+
]
|
508 |
+
|
509 |
+
|
510 |
+
@pytest.mark.parametrize("filename", OIG_DATASETS)
|
511 |
+
def test_get_small_sample_oig_data(filename):
|
512 |
+
if not os.path.exists(filename):
|
513 |
+
os.system('wget https://huggingface.co/datasets/laion/OIG/resolve/main/%s' % filename)
|
514 |
+
import json
|
515 |
+
rows = []
|
516 |
+
with open(filename, "r") as f:
|
517 |
+
for line in f.readlines():
|
518 |
+
row = json.loads(line)
|
519 |
+
rows.append(dict(input=row["text"]))
|
520 |
+
with open(filename + POSTFIX, "w") as f:
|
521 |
+
f.write(json.dumps(rows, indent=2))
|
522 |
+
|
523 |
+
|
524 |
+
@pytest.mark.parametrize("filename", ALL_OIG_DATASETS)
|
525 |
+
def test_download_useful_data_as_parquet(filename):
|
526 |
+
dest_file = filename + '.parquet'
|
527 |
+
if dest_file not in useful_oig_files:
|
528 |
+
pytest.skip('file declared not useful')
|
529 |
+
if not os.path.exists(filename):
|
530 |
+
os.system('wget https://huggingface.co/datasets/laion/OIG/resolve/main/%s' % filename)
|
531 |
+
if not os.path.exists(dest_file):
|
532 |
+
df = pd.read_json(path_or_buf=filename, lines=True)
|
533 |
+
df.to_parquet(dest_file, index=False)
|
534 |
+
|
535 |
+
|
536 |
+
def test_merge_shuffle_small_sample_oig_data():
|
537 |
+
np.random.seed(1234)
|
538 |
+
rows = []
|
539 |
+
for filename in OIG_DATASETS:
|
540 |
+
with open(filename + POSTFIX, "r") as f:
|
541 |
+
rows.extend(json.loads(f.read()))
|
542 |
+
np.random.shuffle(rows)
|
543 |
+
with open("merged_shuffled_OIG_%s.json" % hashlib.sha256(str(OIG_DATASETS).encode()).hexdigest()[:10], "w") as f:
|
544 |
+
f.write(json.dumps(rows, indent=2))
|
545 |
+
|
546 |
+
|
547 |
+
def test_join_jsons():
|
548 |
+
files = ['config.json'] * 1 + \
|
549 |
+
['dai_docs.train_cleaned.json'] * 2 + \
|
550 |
+
['dai_faq.json'] * 3
|
551 |
+
print(files)
|
552 |
+
lst = []
|
553 |
+
[lst.extend(json.load(open(fil, 'rt'))) for fil in files]
|
554 |
+
print(len(lst))
|
555 |
+
json.dump(lst, open("merged.json", "wt"), indent=2)
|
556 |
+
|
557 |
+
|
558 |
+
@pytest.mark.parametrize("filename", ['Anthropic/hh-rlhf'])
|
559 |
+
def test_make_rlhf_good_data(filename):
|
560 |
+
from datasets import load_dataset
|
561 |
+
rows = load_dataset(filename)["train"]["chosen"]
|
562 |
+
new_rows = []
|
563 |
+
for row in rows:
|
564 |
+
if row[:2] == "\n\n":
|
565 |
+
row = row[2:]
|
566 |
+
row = row.replace("Human: ", "<human>: ")
|
567 |
+
row = row.replace("Assistant: ", "<bot>: ")
|
568 |
+
new_rows.append(dict(input=row))
|
569 |
+
with open(filename.replace("/", "_") + POSTFIX, "w") as f:
|
570 |
+
f.write(json.dumps(new_rows, indent=2))
|
571 |
+
|
572 |
+
|
573 |
+
|
574 |
+
def test_show_prompts():
|
575 |
+
files = ['config.json'] * 1 + \
|
576 |
+
['dai_docs.train_cleaned.json'] * 1 + \
|
577 |
+
['dai_faq.json'] * 1
|
578 |
+
file_points = [json.load(open(fil, 'rt')) for fil in files]
|
579 |
+
from prompter import generate_prompt
|
580 |
+
for data_points in file_points:
|
581 |
+
for data_point in data_points:
|
582 |
+
print(generate_prompt(data_point, 'plain', False, False)[0])
|
583 |
+
|
584 |
+
|
585 |
+
def test_get_open_datasets():
|
586 |
+
# HF changed things so don't get raw list of all datasets, so not have to filter, but can't do negative filter
|
587 |
+
open_tags = ['license:Apache License 2.0',
|
588 |
+
'license:mit',
|
589 |
+
'license:apache',
|
590 |
+
'license:apache2',
|
591 |
+
'license:apache-2.0',
|
592 |
+
'license:bsd',
|
593 |
+
'license:bsd-2-clause',
|
594 |
+
'license:bsd-3-clause',
|
595 |
+
'license:bsd-3-clause-clear',
|
596 |
+
'license:lgpl-2.1',
|
597 |
+
'license:lgpl-3.0',
|
598 |
+
'license:lgpl-lr',
|
599 |
+
'license:lgpl',
|
600 |
+
'license:openrail++',
|
601 |
+
'license:openrail',
|
602 |
+
'license:bigscience-bloom-rail-1.0',
|
603 |
+
#'license:agpl-3.0',
|
604 |
+
'license:other',
|
605 |
+
'license:unknown',
|
606 |
+
# 'license:mpl-2.0', # ok, but would have to include original copyright, license, source, copies in distribution
|
607 |
+
# Attribution required:
|
608 |
+
'license:odc-by',
|
609 |
+
'license:cc-by-4.0',
|
610 |
+
'license:cc-by-3.0',
|
611 |
+
'license:cc-by-2.0',
|
612 |
+
'license:cc-by-2.5',
|
613 |
+
#'license:cc-by-sa-4.0', # would require same license
|
614 |
+
'license:odbl',
|
615 |
+
'license:pddl',
|
616 |
+
'license:ms-pl',
|
617 |
+
'license:zlib',
|
618 |
+
]
|
619 |
+
# bad license: cc-by-nc-4.0
|
620 |
+
|
621 |
+
from huggingface_hub import list_datasets
|
622 |
+
datasets = flatten_list([[x for x in list_datasets(filter=y)] for y in open_tags])
|
623 |
+
datasets += [x for x in list_datasets(author='openai')]
|
624 |
+
# check all:
|
625 |
+
all_license_tags = set(flatten_list([[y for y in x.tags if 'license' in y] for x in datasets]))
|
626 |
+
print(len(all_license_tags))
|
627 |
+
open_datasets = [x for x in datasets if any([y in x.tags for y in open_tags]) or 'license:' not in str(x.tags)]
|
628 |
+
print('open_datasets', len(open_datasets))
|
629 |
+
all_task_tags = set(flatten_list([[y for y in x.tags if 'task' in y] for x in open_datasets]))
|
630 |
+
print('all_task_tags', len(all_task_tags))
|
631 |
+
excluded_tags = ['image', 'hate', 'tabular', 'table-', 'classification', 'retrieval',
|
632 |
+
'translation', 'identification', 'object', 'mask', 'to-text',
|
633 |
+
'face-detection', 'audio', 'voice', 'reinforcement', 'depth-est',
|
634 |
+
'forecasting', 'parsing', 'visual', 'speech', 'multiple-choice',
|
635 |
+
'slot-filling', 'irds/argsme', '-scoring', 'other', 'graph-ml',
|
636 |
+
'feature-extraction', 'keyword-spotting',
|
637 |
+
'coreference-resolution', 'segmentation',
|
638 |
+
'word-sense-disambiguation',
|
639 |
+
'lemmatization']
|
640 |
+
task_tags = [x.replace('task_categories:', '').replace('task_ids:', '')
|
641 |
+
for x in all_task_tags if not any([y in x for y in
|
642 |
+
excluded_tags])]
|
643 |
+
print('task_tags', len(task_tags))
|
644 |
+
# str(x.tags) to catch any pattern match to anything in list
|
645 |
+
open_tasked_datasets = [x for x in open_datasets if
|
646 |
+
any([y in str([x for x in x.tags if 'task' in x]) for y in task_tags]) and
|
647 |
+
not any([y in str([x for x in x.tags if 'task' in x]) for y in excluded_tags]) or
|
648 |
+
'task_categories' not in str(x.tags) and 'task_ids' not in str(x.tags)]
|
649 |
+
open_tasked_datasets = [x for x in open_tasked_datasets if not x.disabled]
|
650 |
+
open_tasked_datasets = [x for x in open_tasked_datasets if not x.gated]
|
651 |
+
open_tasked_datasets = [x for x in open_tasked_datasets if not x.private]
|
652 |
+
print('open_tasked_datasets', len(open_tasked_datasets))
|
653 |
+
sizes = list(set(flatten_list([[(y, x.id) for y in x.tags if 'size' in y] for x in open_tasked_datasets])))
|
654 |
+
languages = list(set(flatten_list([[(y, x.id) for y in x.tags if 'language:' in y] for x in open_tasked_datasets])))
|
655 |
+
open_english_tasked_datasets = [x for x in open_tasked_datasets if
|
656 |
+
'language:' not in str(x.tags) or
|
657 |
+
'language:en' in str(x.tags)]
|
658 |
+
small_open_english_tasked_datasets = [x for x in open_english_tasked_datasets if
|
659 |
+
'n<1K' in str(x.tags) or
|
660 |
+
'1K<n<10K' in str(x.tags) or
|
661 |
+
'1K0<n<100K' in str(x.tags) or
|
662 |
+
'100K<n<1M' in str(x.tags) or
|
663 |
+
'size_category' not in str(x.tags)
|
664 |
+
]
|
665 |
+
# 'aeslc' : email_body, subject -> summarization?
|
666 |
+
# load_dataset(open_tasked_datasets[0].id).data['train'].to_pandas()
|
667 |
+
ids = [x.id for x in small_open_english_tasked_datasets]
|
668 |
+
|
669 |
+
# sanity checks
|
670 |
+
# https://bair.berkeley.edu/blog/2023/04/03/koala/
|
671 |
+
assert 'alespalla/chatbot_instruction_prompts' in ids
|
672 |
+
assert 'laion/OIG' in ids
|
673 |
+
assert 'openai/webgpt_comparisons' in ids
|
674 |
+
assert 'openai/summarize_from_feedback' in ids
|
675 |
+
assert 'Anthropic/hh-rlhf' in ids
|
676 |
+
|
677 |
+
# useful but not allowed for commercial purposes:
|
678 |
+
# https://huggingface.co/datasets/squad
|
679 |
+
|
680 |
+
print('open_english_tasked_datasets: ', ids, flush=True)
|
681 |
+
|
682 |
+
exclude_ids = ['allenai/nllb', # translation only
|
683 |
+
'hf-internal-testing/fixtures_image_utils', # testing
|
684 |
+
'allenai/c4', # search-url
|
685 |
+
'agemagician/uniref50', # unknown
|
686 |
+
'huggingface-course/documentation-images', # images
|
687 |
+
'smilegate-ai/kor_unsmile', # korean
|
688 |
+
'MohamedRashad/ChatGPT-prompts', # ChatGPT/LearnGPT/https://www.emergentmind.com/
|
689 |
+
'humarin/chatgpt-paraphrases', # Paraphrase using ChatGPT
|
690 |
+
'Jeska/vaccinchat', # not useful
|
691 |
+
'alespalla/chatbot_instruction_prompts', # mixes alpaca
|
692 |
+
'allenai/prosocial-dialog', # already exlucded, but wrongly in other datasets that say more permissive license
|
693 |
+
'AlekseyKorshuk/persona-chat', # low quality
|
694 |
+
'bavard/personachat_truecased', # low quality
|
695 |
+
'adamlin/daily_dialog', # medium quality conversations
|
696 |
+
'adamlin/FewShotWoz', # low quality
|
697 |
+
'benjaminbeilharz/better_daily_dialog', # low quality
|
698 |
+
'benjaminbeilharz/daily_dialog_w_turn_templates', # low
|
699 |
+
'benjaminbeilharz/empathetic_dialogues_for_lm', # low
|
700 |
+
'GEM-submissions/GEM__bart_base_schema_guided_dialog__1645547915', # NA
|
701 |
+
'ia-bentebib/conv_ai_2_fr', # low fr
|
702 |
+
'ia-bentebib/daily_dialog_fr', # low fr
|
703 |
+
'ia-bentebib/dialog_re_fr', # low fr
|
704 |
+
'ia-bentebib/empathetic_dialogues_fr', # low fr
|
705 |
+
'roskoN/dailydialog', # low
|
706 |
+
'VadorMazer/skyrimdialogstest', # low
|
707 |
+
'bigbio/med_qa', # med specific Q/A
|
708 |
+
'biu-nlp/qa_srl2018', # low quality Q/A
|
709 |
+
'biu-nlp/qa_discourse', # low quality Q/A
|
710 |
+
'iarfmoose/qa_evaluator', # low quality Q/A
|
711 |
+
'jeopardy', # low quality Q/A -- no reasoning
|
712 |
+
'narrativeqa', # low quality Q/A
|
713 |
+
'nomic-ai/gpt4all_prompt_generations', # bad license
|
714 |
+
'nomic-ai/gpt4all_prompt_generations_with_p3', # bad license
|
715 |
+
'HuggingFaceH4/alpaca', # bad license
|
716 |
+
'tatsu-lab/alpaca', # ToS breaking
|
717 |
+
'yahma/alpaca-cleaned', # ToS breaking
|
718 |
+
'Hello-SimpleAI/HC3', # bad license
|
719 |
+
'glue', # no reasoning QA
|
720 |
+
'sahil2801/CodeAlpaca-20k', # bad license
|
721 |
+
'Short-Answer-Feedback/saf_communication_networks_english', # long Q, medium A
|
722 |
+
]
|
723 |
+
small_open_english_tasked_datasets = [x for x in small_open_english_tasked_datasets if x.id not in exclude_ids]
|
724 |
+
# some ids clearly speech related
|
725 |
+
small_open_english_tasked_datasets = [x for x in small_open_english_tasked_datasets if 'speech' not in x.id]
|
726 |
+
# HF testing
|
727 |
+
small_open_english_tasked_datasets = [x for x in small_open_english_tasked_datasets if 'hf-internal-testing' not in x.id]
|
728 |
+
small_open_english_tasked_datasets = [x for x in small_open_english_tasked_datasets if
|
729 |
+
'chinese' not in x.id]
|
730 |
+
|
731 |
+
sorted_small_open_english_tasked_datasets = sorted([(x.downloads, x) for x in small_open_english_tasked_datasets],
|
732 |
+
key=lambda x: x[0], reverse=True)
|
733 |
+
|
734 |
+
# NOTES:
|
735 |
+
# Run like pytest -s -v create_data.py::test_get_open_datasets &> getdata9.log
|
736 |
+
# See what needs config passed and add:
|
737 |
+
# grep 'load_dataset(' getdata9.log|grep -v data_id|less -S
|
738 |
+
# grep "pip install" getdata9.log
|
739 |
+
# NOTE: Some datasets have default config, but others are there. Don't know how to access them.
|
740 |
+
|
741 |
+
|
742 |
+
"""
|
743 |
+
https://huggingface.co/datasets/wikihow/blob/main/wikihow.py
|
744 |
+
https://github.com/mahnazkoupaee/WikiHow-Dataset
|
745 |
+
https://ucsb.box.com/s/ap23l8gafpezf4tq3wapr6u8241zz358
|
746 |
+
https://ucsb.app.box.com/s/ap23l8gafpezf4tq3wapr6u8241zz358
|
747 |
+
"""
|
748 |
+
|
749 |
+
"""
|
750 |
+
# some ambiguous or non-commercial datasets
|
751 |
+
https://github.com/PhoebusSi/alpaca-CoT
|
752 |
+
"""
|
753 |
+
|
754 |
+
timeout = 3 * 60
|
755 |
+
# laion/OIG takes longer
|
756 |
+
for num_downloads, dataset in sorted_small_open_english_tasked_datasets:
|
757 |
+
data_id = dataset.id
|
758 |
+
func = do_one
|
759 |
+
args = (data_id, num_downloads)
|
760 |
+
kwargs = {}
|
761 |
+
with ProcessPoolExecutor(max_workers=1) as executor:
|
762 |
+
future = executor.submit(func, *args, **kwargs)
|
763 |
+
try:
|
764 |
+
future.result(timeout=timeout)
|
765 |
+
except concurrent.futures.TimeoutError:
|
766 |
+
print("\n\ndata_id %s timeout\n\n" % data_id, flush=True)
|
767 |
+
for child in psutil.Process(os.getpid()).children(recursive=True):
|
768 |
+
os.kill(child.pid, signal.SIGINT)
|
769 |
+
os.kill(child.pid, signal.SIGTERM)
|
770 |
+
os.kill(child.pid, signal.SIGKILL)
|
771 |
+
|
772 |
+
|
773 |
+
def do_one(data_id, num_downloads):
|
774 |
+
from datasets import load_dataset
|
775 |
+
out_file = "data_%s.parquet" % str(data_id.replace('/', '_'))
|
776 |
+
if os.path.isfile(out_file) and os.path.getsize(out_file) > 1024**3:
|
777 |
+
return
|
778 |
+
try:
|
779 |
+
print("Loading data_id %s num_downloads: %s" % (data_id, num_downloads), flush=True)
|
780 |
+
avail_list = None
|
781 |
+
try:
|
782 |
+
data = load_dataset(data_id, 'foobar')
|
783 |
+
except Exception as e:
|
784 |
+
if 'Available: ' in str(e):
|
785 |
+
avail_list = ast.literal_eval(str(e).split('Available:')[1].strip())
|
786 |
+
else:
|
787 |
+
avail_list = None
|
788 |
+
if avail_list is None:
|
789 |
+
avail_list = [None]
|
790 |
+
print("%s avail_list: %s" % (data_id, avail_list), flush=True)
|
791 |
+
|
792 |
+
for name in avail_list:
|
793 |
+
out_file = "data_%s_%s.parquet" % (str(data_id.replace('/', '_')), str(name))
|
794 |
+
if os.path.isfile(out_file):
|
795 |
+
continue
|
796 |
+
data = load_dataset(data_id, name)
|
797 |
+
column_names_dict = data.column_names
|
798 |
+
column_names = column_names_dict[list(column_names_dict.keys())[0]]
|
799 |
+
print("Processing data_id %s num_downloads: %s columns: %s" % (data_id, num_downloads, column_names),
|
800 |
+
flush=True)
|
801 |
+
data_dict = data.data
|
802 |
+
col_dict = data.num_columns
|
803 |
+
first_col = list(col_dict.keys())[0]
|
804 |
+
if 'train' in data_dict:
|
805 |
+
df = data['train'].to_pandas()
|
806 |
+
else:
|
807 |
+
df = data[first_col].to_pandas()
|
808 |
+
# csv has issues with escaping chars, even for datasets I know I want
|
809 |
+
df.to_parquet(out_file, index=False)
|
810 |
+
except Exception as e:
|
811 |
+
t, v, tb = sys.exc_info()
|
812 |
+
ex = ''.join(traceback.format_exception(t, v, tb))
|
813 |
+
print("Exception: %s %s" % (data_id, ex), flush=True)
|
814 |
+
|
815 |
+
|
816 |
+
def test_otherlic():
|
817 |
+
from huggingface_hub import list_datasets
|
818 |
+
lic = ['license:odc-by',
|
819 |
+
'license:cc-by-4.0',
|
820 |
+
'license:cc-by-3.0',
|
821 |
+
'license:cc-by-2.0',
|
822 |
+
'license:cc-by-2.5',
|
823 |
+
'license:cc-by-sa-4.0',
|
824 |
+
'license:odbl',
|
825 |
+
'license:pddl',
|
826 |
+
'license:ms-pl',
|
827 |
+
'license:zlib',
|
828 |
+
]
|
829 |
+
datasets = flatten_list([[x for x in list_datasets(filter=y) if 'translation' not in str(x.tags)] for y in lic])
|
830 |
+
print(len(datasets))
|
831 |
+
|
832 |
+
|
833 |
+
# These useful datasets are determined based upon data sample, column types, and uniqueness compared to larger datasets like Pile
|
834 |
+
# grep columns getdata13.log|grep -v "\['image'\]"|sort|uniq|grep -v tokens|grep -v "'image'"|grep -v embedding|grep dialog
|
835 |
+
useful = ['Dahoas/instruct-human-assistant-prompt',
|
836 |
+
'Dahoas/first-instruct-human-assistant-prompt',
|
837 |
+
'knkarthick/dialogsum', # summary of conversation
|
838 |
+
'McGill-NLP/FaithDial', # medium quality
|
839 |
+
'Zaid/quac_expanded', # medium quality context + QA
|
840 |
+
'0-hero/OIG-small-chip2', # medium
|
841 |
+
'alistvt/coqa-flat', # QA medium
|
842 |
+
'AnonymousSub/MedQuAD_47441_Question_Answer_Pairs', # QA medium
|
843 |
+
'Anthropic/hh-rlhf', # high quality # similar to Dahoas/full-hh-rlhf
|
844 |
+
'arjunth2001/online_privacy_qna', # good quality QA
|
845 |
+
'Dahoas/instruct_helpful_preferences', # medium quality instruct
|
846 |
+
'Dahoas/rl-prompt-dataset', # medium chat
|
847 |
+
'Dahoas/rm-static', # medium chat
|
848 |
+
'Dahoas/static-hh', # medium chat # HuggingFaceH4/self_instruct
|
849 |
+
'Dahoas/synthetic-instruct-gptj-pairwise', # medium chat
|
850 |
+
'eli5', # QA if prompt ELI5
|
851 |
+
'gsm8k', # QA (various)
|
852 |
+
'guanaco/guanaco', # prompt/response
|
853 |
+
'kastan/rlhf-qa-comparisons', # good QA
|
854 |
+
'kastan/rlhf-qa-conditional-generation-v2', # prompt answer
|
855 |
+
'OllieStanley/humaneval-mbpp-codegen-qa', # code QA, but started from words, so better than other code QA
|
856 |
+
'OllieStanley/humaneval-mbpp-testgen-qa', # code QA
|
857 |
+
'Graverman/Instruct-to-Code', # code QA
|
858 |
+
'openai/summarize_from_feedback', # summarize
|
859 |
+
'relbert/analogy_questions', # analogy QA
|
860 |
+
'yitingxie/rlhf-reward-datasets', # prompt, chosen, rejected.
|
861 |
+
'yizhongw/self_instruct', # instruct (super natural & instruct)
|
862 |
+
'HuggingFaceH4/asss', # QA, big A
|
863 |
+
'kastan/rlhf-qa-conditional-generation-v2', # QA
|
864 |
+
'cosmos_qa', # context QA
|
865 |
+
'vishal-burman/c4-faqs', # QA but not so much reasoning, but alot of text
|
866 |
+
'squadshifts', # QA from context
|
867 |
+
'hotpot_qa', # QA from context
|
868 |
+
'adversarial_qa', # QA from context
|
869 |
+
'allenai/soda', # dialog -> narrative/summary
|
870 |
+
'squad_v2', # context QA
|
871 |
+
'squadshifts', # context QA
|
872 |
+
'dferndz/cSQuAD1', # context QA
|
873 |
+
'dferndz/cSQuAD2', # context QA
|
874 |
+
'din0s/msmarco-nlgen', # context QA
|
875 |
+
'domenicrosati/TruthfulQA', # common sense truthful QA -- trivia but good trivia
|
876 |
+
'hotpot_qa', # context, QA
|
877 |
+
'HuggingFaceH4/self-instruct-eval', # instruct QA, medium quality, some language reasoning
|
878 |
+
'kastan/EE_QA_for_RLHF', # context QA
|
879 |
+
'KK04/LogicInference_OA', # instruction logical QA
|
880 |
+
'lmqg/qa_squadshifts_synthetic', # context QA
|
881 |
+
'lmqg/qg_squad', # context QA
|
882 |
+
'lmqg/qg_squadshifts', # context QA
|
883 |
+
'lmqg/qg_subjqa', # context QA
|
884 |
+
'pszemraj/HC3-textgen-qa', # QA medium, has human responses -- humans tend to provide links instead of trying to answer
|
885 |
+
'pythonist/newdata', # long context, QA, brief A
|
886 |
+
'ropes', # long background, situation, question, A
|
887 |
+
'wikitablequestions', # table -> QA
|
888 |
+
'bigscience/p3', # context QA but short answers
|
889 |
+
]
|
890 |
+
|
891 |
+
|
892 |
+
|
893 |
+
code_useful = ['0n1xus/codexglue',
|
894 |
+
'openai_humaneval',
|
895 |
+
'koutch/staqc',
|
896 |
+
]
|
897 |
+
|
898 |
+
|
899 |
+
maybe_useful = ['AlekseyKorshuk/comedy-scripts',
|
900 |
+
'openbookqa', # hard to parse, low reasoning
|
901 |
+
'qed', # reasonable QA, but low reasoning
|
902 |
+
'selqa', # candidate answers
|
903 |
+
'HuggingFaceH4/instruction-pilot-outputs-filtered',
|
904 |
+
'GBaker/MedQA-USMLE-4-options', # medical QA with long questions
|
905 |
+
'npc-engine/light-batch-summarize-dialogue', # dialog summarize, kinda low specific quality
|
906 |
+
]
|
907 |
+
|
908 |
+
|
909 |
+
summary_useful = ['austin/rheum_abstracts',
|
910 |
+
'CarperAI/openai_summarize_comparisons', # summarize chosen/rejected
|
911 |
+
'CarperAI/openai_summarize_tldr', # summarize QA
|
912 |
+
'ccdv/cnn_dailymail', # summarize news
|
913 |
+
'ccdv/govreport-summarization', # summarize high quality
|
914 |
+
'ccdv/pubmed-summarization', # summarize high quality
|
915 |
+
'duorc', # plot -> QA
|
916 |
+
'farleyknight/big_patent_5_percent', # desc -> abstract
|
917 |
+
'multi_news', # summary
|
918 |
+
'opinosis',
|
919 |
+
'SophieTr/reddit_clean',
|
920 |
+
'allenai/mup', # long text -> summary
|
921 |
+
'allenai/multi_lexsum', # long text -> summary
|
922 |
+
'big_patent',
|
923 |
+
'allenai/wcep_dense_max',
|
924 |
+
'awinml/costco_long_practice',
|
925 |
+
'GEM/xsum',
|
926 |
+
'ratishsp/newshead',
|
927 |
+
'RussianNLP/wikiomnia', # russian
|
928 |
+
'stacked-summaries/stacked-xsum-1024',
|
929 |
+
]
|
930 |
+
|
931 |
+
|
932 |
+
math_useful = [
|
933 |
+
'competition_math'
|
934 |
+
]
|
935 |
+
|
936 |
+
|
937 |
+
skipped = ['c4', # maybe useful, used for flan, but skipped due to size
|
938 |
+
]
|
939 |
+
|
940 |
+
"""
|
941 |
+
To get training data from oig:
|
942 |
+
pytest test_oig test_grade_final test_finalize_to_json
|
943 |
+
"""
|
944 |
+
|
945 |
+
human = '<human>:'
|
946 |
+
bot = '<bot>:'
|
947 |
+
|
948 |
+
|
949 |
+
def test_assemble_and_detox():
|
950 |
+
import re
|
951 |
+
from profanity_check import predict_prob
|
952 |
+
df_list = []
|
953 |
+
for data in useful_oig_files:
|
954 |
+
print("Processing %s" % data, flush=True)
|
955 |
+
df = pd.read_parquet(data)
|
956 |
+
df = df.reset_index(drop=True)
|
957 |
+
# chop up into human/bot interactions of no more than 10kB per row
|
958 |
+
text_list = df[['text']].values.ravel().tolist()
|
959 |
+
new_text = []
|
960 |
+
max_len = 2048 # uber cutoff
|
961 |
+
MAX_LEN = 2048//2 - 30 # max len per question/answer
|
962 |
+
for text in tqdm(text_list):
|
963 |
+
human_starts = [m.start() for m in re.finditer('<human>: ', text)]
|
964 |
+
if len(human_starts) == 1:
|
965 |
+
human_starts = [0, len(text)] # always go into for loop below
|
966 |
+
blurb = ''
|
967 |
+
for i in range(len(human_starts) - 1):
|
968 |
+
interaction = text[human_starts[i]: human_starts[i+1]][:max_len]
|
969 |
+
blurb += interaction
|
970 |
+
if len(blurb) >= MAX_LEN:
|
971 |
+
blurb = get_sentences(blurb, length=MAX_LEN)[0]
|
972 |
+
new_text.append(blurb + "\n<human>:")
|
973 |
+
blurb = ''
|
974 |
+
if blurb:
|
975 |
+
blurb = get_sentences(blurb, length=MAX_LEN)[0]
|
976 |
+
new_text.append(blurb + "\n<human>:")
|
977 |
+
|
978 |
+
if len(new_text) > len(text_list):
|
979 |
+
print("Added %d new rows (before: %d)" % (len(new_text) - df.shape[0], df.shape[0]))
|
980 |
+
df = pd.DataFrame({"text": new_text, "source": [data] * len(new_text)})
|
981 |
+
df = df.drop_duplicates(keep='first')
|
982 |
+
print(df['text'].apply(lambda x: len(x)).describe())
|
983 |
+
assert df['text'].apply(lambda x: len(x)).max() <= 2 * max_len
|
984 |
+
|
985 |
+
# faster than better_profanity, do early
|
986 |
+
df['profanity'] = predict_prob(df['text'])
|
987 |
+
before_rows = df.shape[0]
|
988 |
+
df = df[df['profanity'] < 0.25] # drop any low quality stuff
|
989 |
+
after_rows = df.shape[0]
|
990 |
+
print("Dropped %d rows out of %d due to alt-profanity-check" % (before_rows - after_rows, before_rows))
|
991 |
+
df_list.append(df)
|
992 |
+
print("Done processing %s -> %s rows" % (data, df.shape[0]), flush=True)
|
993 |
+
print("So far have %d rows" % sum([len(x) for x in df_list]))
|
994 |
+
df_final = pd.concat(df_list)
|
995 |
+
df_final = df_final.sample(frac=1, random_state=1234).reset_index(drop=True)
|
996 |
+
df_final.to_parquet('h2oGPT.cleaned.human_bot.shorter.parquet', index=False)
|
997 |
+
|
998 |
+
|
999 |
+
def test_basic_cleaning():
|
1000 |
+
# from better_profanity import profanity
|
1001 |
+
# https://pypi.org/project/alt-profanity-check/
|
1002 |
+
from profanity_check import predict
|
1003 |
+
df_list = []
|
1004 |
+
for data in useful_oig_files:
|
1005 |
+
#for data in useful_oig_files[:5]:
|
1006 |
+
#for data in ['unified_openai_summarize_tldr.jsonl.parquet']:
|
1007 |
+
print("Processing %s" % data, flush=True)
|
1008 |
+
df = pd.read_parquet(data)
|
1009 |
+
df = df.reset_index(drop=True)
|
1010 |
+
# NOTE: Not correct if multiple human-bot interactions, but those dialogs even more desired
|
1011 |
+
#avg_chars = len(df['text'][0])/(df['text'][0].count(human)+df['text'][0].count(bot))
|
1012 |
+
df['avg_words'] = df['text'].apply(lambda x: x.count(' ') / (x.count(human) + x.count(bot))/2.0)
|
1013 |
+
df['avg_bot_words'] = df['text'].apply(lambda x: x.split(bot)[1].count(' ') / x.count(bot))
|
1014 |
+
#df['bad_words'] = df['text'].apply(lambda x: profanity.contains_profanity(x))
|
1015 |
+
#low_quality_patterns = ['Write the rest of this wikipedia article']
|
1016 |
+
res = predict(df['text'])
|
1017 |
+
df['bad_words'] = res
|
1018 |
+
df = df.reset_index(drop=True)
|
1019 |
+
df = df[df['bad_words'] == 0]
|
1020 |
+
df = df[['text', 'avg_words', 'avg_bot_words']]
|
1021 |
+
df = df.drop_duplicates(keep='first')
|
1022 |
+
print(df[df['avg_words'] == df['avg_words'].max()]['text'].values)
|
1023 |
+
median_words = np.median(df['avg_words'])
|
1024 |
+
min_words_per_entity = max(30, 0.8 * median_words)
|
1025 |
+
max_words_per_entity = 2048 # too hard to learn from for now
|
1026 |
+
df = df[df['avg_words'] > min_words_per_entity]
|
1027 |
+
df = df[df['avg_words'] < max_words_per_entity]
|
1028 |
+
|
1029 |
+
min_words_per_entity = max(20, 0.5 * median_words) # bot should say stuff for now
|
1030 |
+
max_words_per_entity = 2048 # too hard to learn from for now
|
1031 |
+
df = df[df['avg_bot_words'] > min_words_per_entity]
|
1032 |
+
df = df[df['avg_bot_words'] < max_words_per_entity]
|
1033 |
+
|
1034 |
+
df_list.append(df)
|
1035 |
+
print("Done processing %s -> %s rows" % (data, df.shape[0]), flush=True)
|
1036 |
+
df_final = pd.concat(df_list)
|
1037 |
+
df_final.to_parquet('h2oGPT.cleaned.human_bot.parquet', index=False)
|
1038 |
+
|
1039 |
+
|
1040 |
+
from joblib import Parallel, delayed, effective_n_jobs
|
1041 |
+
from sklearn.utils import gen_even_slices
|
1042 |
+
from sklearn.utils.validation import _num_samples
|
1043 |
+
|
1044 |
+
|
1045 |
+
def parallel_apply(df, func, n_jobs=-1, **kwargs):
|
1046 |
+
""" Pandas apply in parallel using joblib.
|
1047 |
+
Uses sklearn.utils to partition input evenly.
|
1048 |
+
|
1049 |
+
Args:
|
1050 |
+
df: Pandas DataFrame, Series, or any other object that supports slicing and apply.
|
1051 |
+
func: Callable to apply
|
1052 |
+
n_jobs: Desired number of workers. Default value -1 means use all available cores.
|
1053 |
+
**kwargs: Any additional parameters will be supplied to the apply function
|
1054 |
+
|
1055 |
+
Returns:
|
1056 |
+
Same as for normal Pandas DataFrame.apply()
|
1057 |
+
|
1058 |
+
"""
|
1059 |
+
|
1060 |
+
if effective_n_jobs(n_jobs) == 1:
|
1061 |
+
return df.apply(func, **kwargs)
|
1062 |
+
else:
|
1063 |
+
ret = Parallel(n_jobs=n_jobs)(
|
1064 |
+
delayed(type(df).apply)(df[s], func, **kwargs)
|
1065 |
+
for s in gen_even_slices(_num_samples(df), effective_n_jobs(n_jobs)))
|
1066 |
+
return pd.concat(ret)
|
1067 |
+
|
1068 |
+
|
1069 |
+
def add_better_profanity_flag(df):
|
1070 |
+
from better_profanity import profanity
|
1071 |
+
df['better_profanity'] = parallel_apply(
|
1072 |
+
df['text'],
|
1073 |
+
lambda x: profanity.contains_profanity(x),
|
1074 |
+
n_jobs=-1,
|
1075 |
+
)
|
1076 |
+
return df
|
1077 |
+
|
1078 |
+
|
1079 |
+
def add_textstat_grade(df):
|
1080 |
+
import textstat
|
1081 |
+
|
1082 |
+
def myfunc(x):
|
1083 |
+
return textstat.flesch_kincaid_grade(x) # simple grade
|
1084 |
+
|
1085 |
+
if False:
|
1086 |
+
import dask.dataframe as dd
|
1087 |
+
# 40 seconds for 1000 rows, but have 1,787,799 rows
|
1088 |
+
ddata = dd.from_pandas(df, npartitions=120)
|
1089 |
+
|
1090 |
+
df['flesch_grade'] = ddata['text'].apply(myfunc).compute()
|
1091 |
+
if True:
|
1092 |
+
# fast way
|
1093 |
+
df['flesch_grade'] = parallel_apply(df['text'], myfunc, n_jobs=-1)
|
1094 |
+
return df
|
1095 |
+
|
1096 |
+
|
1097 |
+
def add_deberta_grade(df):
|
1098 |
+
from transformers import AutoModelForSequenceClassification, AutoTokenizer
|
1099 |
+
import torch
|
1100 |
+
reward_name = "OpenAssistant/reward-model-deberta-v3-large-v2"
|
1101 |
+
rank_model, tokenizer = AutoModelForSequenceClassification.from_pretrained(
|
1102 |
+
reward_name), AutoTokenizer.from_pretrained(reward_name)
|
1103 |
+
device = 'cuda' if torch.cuda.is_available() else 'cpu'
|
1104 |
+
rank_model.to(device)
|
1105 |
+
|
1106 |
+
def get_question(x):
|
1107 |
+
return x.replace('<human>: ', '').split('<bot>:')[0]
|
1108 |
+
|
1109 |
+
def get_answer(x):
|
1110 |
+
try:
|
1111 |
+
answer = x.split('<bot>: ')[1].split('<human>:')[0].replace('<bot>: ', '')
|
1112 |
+
except:
|
1113 |
+
answer = x.split('<bot>:')[1].split('<human>:')[0].replace('<bot>:', '')
|
1114 |
+
return answer
|
1115 |
+
|
1116 |
+
df['question'] = parallel_apply(df['text'], get_question, n_jobs=-1)
|
1117 |
+
df['answer'] = parallel_apply(df['text'], get_answer, n_jobs=-1)
|
1118 |
+
|
1119 |
+
from datasets import Dataset
|
1120 |
+
from transformers import pipeline
|
1121 |
+
from transformers.pipelines.pt_utils import KeyPairDataset
|
1122 |
+
import tqdm
|
1123 |
+
|
1124 |
+
pipe = pipeline(
|
1125 |
+
"text-classification",
|
1126 |
+
model=reward_name,
|
1127 |
+
device="cuda:0" if torch.cuda.is_available() else "cpu"
|
1128 |
+
)
|
1129 |
+
start = 0
|
1130 |
+
batch_size = 64 * 16
|
1131 |
+
micro_batch = orig_micro_batch = 16
|
1132 |
+
end = 0
|
1133 |
+
import socket
|
1134 |
+
checkpoint = "grades.%s.pkl" % socket.gethostname()
|
1135 |
+
grades = []
|
1136 |
+
import pickle
|
1137 |
+
if os.path.exists(checkpoint):
|
1138 |
+
with open(checkpoint, "rb") as f:
|
1139 |
+
start, grades = pickle.loads(f.read())
|
1140 |
+
last_oom = 0
|
1141 |
+
while end < df.shape[0]:
|
1142 |
+
# manual batching to handle OOM more gracefully
|
1143 |
+
end = min(start + batch_size, df.shape[0])
|
1144 |
+
if start == end:
|
1145 |
+
break
|
1146 |
+
dataset = Dataset.from_pandas(df.iloc[start:end, :])
|
1147 |
+
try:
|
1148 |
+
grades.extend([
|
1149 |
+
x['score'] for x in tqdm.tqdm(
|
1150 |
+
pipe(KeyPairDataset(dataset, "question", "answer"), batch_size=micro_batch)
|
1151 |
+
)
|
1152 |
+
])
|
1153 |
+
except torch.cuda.OutOfMemoryError:
|
1154 |
+
last_oom = start
|
1155 |
+
micro_batch = max(1, micro_batch // 2)
|
1156 |
+
print("OOM - retrying with micro_batch=%d" % micro_batch)
|
1157 |
+
continue
|
1158 |
+
if last_oom == start:
|
1159 |
+
micro_batch = orig_micro_batch
|
1160 |
+
print("Returning to micro_batch=%d" % micro_batch)
|
1161 |
+
assert len(grades) == end
|
1162 |
+
start = end
|
1163 |
+
with open(checkpoint, "wb") as f:
|
1164 |
+
f.write(pickle.dumps((end, grades)))
|
1165 |
+
print("%d/%d" % (end, df.shape[0]))
|
1166 |
+
df['grade_deberta'] = grades
|
1167 |
+
if os.path.exists(checkpoint):
|
1168 |
+
os.remove(checkpoint)
|
1169 |
+
return df
|
1170 |
+
|
1171 |
+
|
1172 |
+
def test_chop_by_lengths():
|
1173 |
+
file = "h2oGPT.cleaned.human_bot.shorter.parquet"
|
1174 |
+
df = pd.read_parquet(file).reset_index(drop=True)
|
1175 |
+
df = count_human_bot_lengths(df)
|
1176 |
+
df['rand'] = np.random.rand(df.shape[0])
|
1177 |
+
df['rand2'] = np.random.rand(df.shape[0])
|
1178 |
+
before_rows = df.shape[0]
|
1179 |
+
# throw away short human/bot responses with higher likelihood
|
1180 |
+
df = df[(df['len_human_mean'] > 20)] # never keep very short ones
|
1181 |
+
df = df[(df['len_human_mean'] > 30) | (df['rand'] < 0.2)]
|
1182 |
+
df = df[(df['len_human_mean'] > 50) | (df['rand'] < 0.5)]
|
1183 |
+
df = df[(df['len_human_max'] < 10000)] # drop super long (basically only human) ones
|
1184 |
+
df = df[(df['len_bot_mean'] > 20)] # never keep very short ones
|
1185 |
+
df = df[(df['len_bot_mean'] > 30) | (df['rand2'] < 0.2)]
|
1186 |
+
df = df[(df['len_bot_mean'] > 50) | (df['rand2'] < 0.5)]
|
1187 |
+
df = df[(df['len_bot_max'] < 10000)] # drop super long (only bot) ones
|
1188 |
+
assert df['text'].apply(lambda x: len(x)).max() < 20000
|
1189 |
+
df = df.drop(['rand', 'rand2'], axis=1)
|
1190 |
+
after_rows = df.shape[0]
|
1191 |
+
print("Chopped off %d out of %d rows due to length" % (before_rows - after_rows, before_rows))
|
1192 |
+
print(df.describe())
|
1193 |
+
df.to_parquet('h2oGPT.cleaned.chopped.human_bot.shorter.parquet', index=False)
|
1194 |
+
|
1195 |
+
|
1196 |
+
def count_human_bot_lengths(df, human=None, bot=None):
|
1197 |
+
import re
|
1198 |
+
len_human_min = []
|
1199 |
+
len_human_max = []
|
1200 |
+
len_human_mean = []
|
1201 |
+
len_bot_min = []
|
1202 |
+
len_bot_max = []
|
1203 |
+
len_bot_mean = []
|
1204 |
+
human = human or '<human>:'
|
1205 |
+
bot = bot or '<bot>:'
|
1206 |
+
for is_human in [True, False]:
|
1207 |
+
what = human if is_human else bot
|
1208 |
+
other = human if not is_human else bot
|
1209 |
+
for i in range(df.shape[0]):
|
1210 |
+
text = df.loc[i, 'text']
|
1211 |
+
assert isinstance(text, str)
|
1212 |
+
starts = [m.start() for m in re.finditer(what, text)]
|
1213 |
+
if len(starts) == 1:
|
1214 |
+
starts = [starts[0], len(text)] # always go into for loop below
|
1215 |
+
assert len(text)
|
1216 |
+
list_what = []
|
1217 |
+
for ii in range(len(starts) - 1):
|
1218 |
+
interaction = text[starts[ii]: starts[ii+1]]
|
1219 |
+
if other in interaction:
|
1220 |
+
interaction = interaction[:interaction.find(other)]
|
1221 |
+
interaction.strip()
|
1222 |
+
list_what.append(interaction)
|
1223 |
+
if not list_what:
|
1224 |
+
list_what = [''] # handle corrupted data, very rare, leads to sizes 0
|
1225 |
+
if is_human:
|
1226 |
+
len_human_min.append(min([len(x) for x in list_what]))
|
1227 |
+
len_human_max.append(max([len(x) for x in list_what]))
|
1228 |
+
len_human_mean.append(np.mean([len(x) for x in list_what]))
|
1229 |
+
else:
|
1230 |
+
len_bot_min.append(min([len(x) for x in list_what]))
|
1231 |
+
len_bot_max.append(max([len(x) for x in list_what]))
|
1232 |
+
len_bot_mean.append(np.mean([len(x) for x in list_what]))
|
1233 |
+
df['len_human_min'] = len_human_min
|
1234 |
+
df['len_human_max'] = len_human_max
|
1235 |
+
df['len_human_mean'] = len_human_mean
|
1236 |
+
df['len_bot_min'] = len_bot_min
|
1237 |
+
df['len_bot_max'] = len_bot_max
|
1238 |
+
df['len_bot_mean'] = len_bot_mean
|
1239 |
+
np.random.seed(1234)
|
1240 |
+
pd.set_option('display.max_columns', None)
|
1241 |
+
print("Before chopping")
|
1242 |
+
print(df.describe())
|
1243 |
+
return df
|
1244 |
+
|
1245 |
+
|
1246 |
+
def test_grade():
|
1247 |
+
df = None
|
1248 |
+
|
1249 |
+
file = "h2oGPT.cleaned.chopped.human_bot.shorter.parquet"
|
1250 |
+
output_file = "h2oGPT.cleaned.graded1.human_bot.shorter.parquet"
|
1251 |
+
if not os.path.exists(output_file):
|
1252 |
+
if df is None:
|
1253 |
+
df = pd.read_parquet(file).reset_index(drop=True)
|
1254 |
+
df = add_textstat_grade(df)
|
1255 |
+
min_grade = 10
|
1256 |
+
max_grade = 25
|
1257 |
+
df = df[df['flesch_grade'] >= min_grade]
|
1258 |
+
df = df[df['flesch_grade'] <= max_grade]
|
1259 |
+
print("After Flesch grade")
|
1260 |
+
print(df.describe())
|
1261 |
+
df.to_parquet(output_file, index=False)
|
1262 |
+
|
1263 |
+
file = output_file
|
1264 |
+
output_file = "h2oGPT.cleaned.graded2.human_bot.shorter.parquet"
|
1265 |
+
if not os.path.exists(output_file):
|
1266 |
+
# slower than alt-profanity, do last, but do before deberta grading, since that's slower
|
1267 |
+
if df is None:
|
1268 |
+
df = pd.read_parquet(file).reset_index(drop=True)
|
1269 |
+
df = add_better_profanity_flag(df)
|
1270 |
+
before_rows = df.shape[0]
|
1271 |
+
df = df[df['better_profanity'] == 0]
|
1272 |
+
df = df.drop(['better_profanity'], axis=1)
|
1273 |
+
after_rows = df.shape[0]
|
1274 |
+
print("Dropped %d rows out of %d due to better_profanity" % (before_rows - after_rows, before_rows))
|
1275 |
+
print(df.describe())
|
1276 |
+
df.to_parquet(output_file, index=False)
|
1277 |
+
|
1278 |
+
file = output_file
|
1279 |
+
output_file = 'h2oGPT.cleaned.graded3.human_bot.shorter.parquet'
|
1280 |
+
if not os.path.exists(output_file):
|
1281 |
+
if df is None:
|
1282 |
+
df = pd.read_parquet(file).reset_index(drop=True)
|
1283 |
+
df = add_deberta_grade(df)
|
1284 |
+
min_grade = 0.3
|
1285 |
+
max_grade = np.inf
|
1286 |
+
before_rows = df.shape[0]
|
1287 |
+
df = df[df['grade_deberta'] >= min_grade]
|
1288 |
+
df = df[df['grade_deberta'] <= max_grade]
|
1289 |
+
after_rows = df.shape[0]
|
1290 |
+
print("Dropped %d rows out of %d due to deberta grade" % (before_rows - after_rows, before_rows))
|
1291 |
+
print("After DeBERTa grade")
|
1292 |
+
print(df.describe())
|
1293 |
+
df.to_parquet(output_file, index=False)
|
1294 |
+
|
1295 |
+
file = output_file
|
1296 |
+
output_file = 'h2oGPT.cleaned.graded.human_bot.shorter.parquet'
|
1297 |
+
if df is None:
|
1298 |
+
df = pd.read_parquet(file).reset_index(drop=True)
|
1299 |
+
df.to_parquet(output_file, index=False)
|
1300 |
+
|
1301 |
+
|
1302 |
+
@pytest.mark.parametrize(
|
1303 |
+
"fixup_personality, only_personality, deberta_grading",
|
1304 |
+
[
|
1305 |
+
[False, False, False],
|
1306 |
+
[True, True, False],
|
1307 |
+
[True, False, False],
|
1308 |
+
[True, False, True],
|
1309 |
+
]
|
1310 |
+
)
|
1311 |
+
def test_add_open_assistant(fixup_personality, only_personality, deberta_grading, save_json=True):
|
1312 |
+
"""
|
1313 |
+
Flatten tree structure into one row per path from root to leaf
|
1314 |
+
Also turn into human_bot prompting format:
|
1315 |
+
<human>: question\n<bot>: answer <human>: question2\n<bot>: answer2 Etc.
|
1316 |
+
Also saves a .json locally as side-effect
|
1317 |
+
returns list of dicts, containing intput, prompt_type and source
|
1318 |
+
"""
|
1319 |
+
from datasets import load_dataset
|
1320 |
+
data_file = "OpenAssistant/oasst1"
|
1321 |
+
ds = load_dataset(data_file)
|
1322 |
+
df = pd.concat([ds['train'].to_pandas(), ds['validation'].to_pandas()], axis=0)
|
1323 |
+
rows = {}
|
1324 |
+
message_ids = df['message_id'].values.tolist()
|
1325 |
+
message_tree_ids = df['message_tree_id'].values.tolist()
|
1326 |
+
parent_ids = df['parent_id'].values.tolist()
|
1327 |
+
texts = df['text'].values.tolist()
|
1328 |
+
roles = df['role'].values.tolist()
|
1329 |
+
|
1330 |
+
for i in range(df.shape[0]):
|
1331 |
+
# collect all trees
|
1332 |
+
message_id = message_ids[i]
|
1333 |
+
message_tree_id = message_tree_ids[i]
|
1334 |
+
parent_id = parent_ids[i]
|
1335 |
+
text = texts[i]
|
1336 |
+
if fixup_personality:
|
1337 |
+
text = text.replace("Open Assistant", "h2oGPT")
|
1338 |
+
text = text.replace("Open-Assistant", "h2oGPT")
|
1339 |
+
text = text.replace("open-assistant", "h2oGPT")
|
1340 |
+
text = text.replace("OpenAssistant", "h2oGPT")
|
1341 |
+
text = text.replace("open assistant", "h2oGPT")
|
1342 |
+
text = text.replace("Open Assistand", "h2oGPT")
|
1343 |
+
text = text.replace("Open Assitant", "h2oGPT")
|
1344 |
+
text = text.replace("Open Assistent", "h2oGPT")
|
1345 |
+
text = text.replace("Open Assisstant", "h2oGPT")
|
1346 |
+
text = text.replace("Open Assitent", "h2oGPT")
|
1347 |
+
text = text.replace("Open Assitiant", "h2oGPT")
|
1348 |
+
text = text.replace("Open Assistiant", "h2oGPT")
|
1349 |
+
text = text.replace("Open Assitan ", "h2oGPT ")
|
1350 |
+
text = text.replace("Open Assistan ", "h2oGPT ")
|
1351 |
+
text = text.replace("Open Asistant", "h2oGPT")
|
1352 |
+
text = text.replace("Open Assiant", "h2oGPT")
|
1353 |
+
text = text.replace("Assistant", "h2oGPT")
|
1354 |
+
text = text.replace("LAION AI", "H2O.ai")
|
1355 |
+
text = text.replace("LAION-AI", "H2O.ai")
|
1356 |
+
text = text.replace("LAION,", "H2O.ai,")
|
1357 |
+
text = text.replace("LAION.ai", "H2O.ai")
|
1358 |
+
text = text.replace("LAION.", "H2O.ai.")
|
1359 |
+
text = text.replace("LAION", "H2O.ai")
|
1360 |
+
|
1361 |
+
role = roles[i]
|
1362 |
+
new_data = ('<human>: ' if role == 'prompter' else '<bot>: ') + text
|
1363 |
+
entry = dict(message_id=message_id, parent_id=parent_id, text=new_data)
|
1364 |
+
if message_tree_id not in rows:
|
1365 |
+
rows[message_tree_id] = [entry]
|
1366 |
+
else:
|
1367 |
+
rows[message_tree_id].append(entry)
|
1368 |
+
|
1369 |
+
all_rows = []
|
1370 |
+
|
1371 |
+
for node_id in rows:
|
1372 |
+
# order responses in tree, based on message/parent relationship
|
1373 |
+
conversations = []
|
1374 |
+
|
1375 |
+
list_msgs = rows[node_id]
|
1376 |
+
# find start
|
1377 |
+
while len(list_msgs):
|
1378 |
+
for i, leaf in enumerate(list_msgs):
|
1379 |
+
found = False
|
1380 |
+
parent_id = leaf['parent_id']
|
1381 |
+
if parent_id is None:
|
1382 |
+
# conversation starter
|
1383 |
+
conversations.append(leaf)
|
1384 |
+
found = True
|
1385 |
+
else:
|
1386 |
+
for conv in conversations:
|
1387 |
+
# find all conversations to add my message to
|
1388 |
+
if parent_id in conv['message_id'] and parent_id != conv['message_id'][-len(parent_id):]:
|
1389 |
+
# my message doesn't follow conversation
|
1390 |
+
continue
|
1391 |
+
if parent_id == conv['message_id'][-len(parent_id):]:
|
1392 |
+
# my message follows conversation, but fork first, so another follow-on message can do same
|
1393 |
+
conversations.append(conv.copy())
|
1394 |
+
conv['text'] += f"""
|
1395 |
+
{leaf['text']}
|
1396 |
+
"""
|
1397 |
+
conv['message_id'] += leaf['message_id']
|
1398 |
+
found = True
|
1399 |
+
break
|
1400 |
+
if found:
|
1401 |
+
# my content was used, so nuke from list
|
1402 |
+
del list_msgs[i]
|
1403 |
+
break
|
1404 |
+
|
1405 |
+
# now reduce down to final conversations, find the longest chains of message ids
|
1406 |
+
for i, conv in enumerate(conversations):
|
1407 |
+
for j, conv2 in enumerate(conversations):
|
1408 |
+
if i == j:
|
1409 |
+
continue
|
1410 |
+
if conv['message_id'] and conv2['message_id']:
|
1411 |
+
assert conv['message_id'] != conv2['message_id']
|
1412 |
+
# delete the shorter conversation, if one contains the other
|
1413 |
+
if conv['message_id'] in conv2['message_id']:
|
1414 |
+
conv['message_id'] = None
|
1415 |
+
if conv2['message_id'] in conv['message_id']:
|
1416 |
+
conv2['message_id'] = None
|
1417 |
+
conversations = [c for c in conversations if c['message_id']]
|
1418 |
+
if only_personality:
|
1419 |
+
all_rows.extend([dict(input=c['text'] + "\n<human>:", prompt_type='plain', source=data_file) for c in conversations if 'h2oGPT' in c['text']])
|
1420 |
+
else:
|
1421 |
+
all_rows.extend([dict(input=c['text'] + "\n<human>:", prompt_type='plain', source=data_file) for c in conversations if "What is H2O.ai" not in c['text']])
|
1422 |
+
unhelpful = get_unhelpful_list()
|
1423 |
+
all_rows = [x for x in all_rows if not any(u in x['input'] for u in unhelpful)]
|
1424 |
+
personality = create_personality_data()
|
1425 |
+
all_rows.extend(personality * 10)
|
1426 |
+
np.random.seed(123)
|
1427 |
+
np.random.shuffle(all_rows)
|
1428 |
+
print(len(all_rows))
|
1429 |
+
if deberta_grading:
|
1430 |
+
df = pd.DataFrame(all_rows)
|
1431 |
+
df = df.rename(columns={'input': 'text'})
|
1432 |
+
df = add_deberta_grade(df)
|
1433 |
+
df = df.rename(columns={'text': 'input'})
|
1434 |
+
drop = True
|
1435 |
+
if drop:
|
1436 |
+
min_grade = 0.3
|
1437 |
+
max_grade = np.inf
|
1438 |
+
before_rows = df.shape[0]
|
1439 |
+
df = df[df['grade_deberta'] >= min_grade]
|
1440 |
+
df = df[df['grade_deberta'] <= max_grade]
|
1441 |
+
after_rows = df.shape[0]
|
1442 |
+
print("Dropped %d rows out of %d due to deberta grade" % (before_rows - after_rows, before_rows))
|
1443 |
+
print("After DeBERTa grade")
|
1444 |
+
print(df.describe())
|
1445 |
+
all_rows = []
|
1446 |
+
for i in range(df.shape[0]):
|
1447 |
+
all_rows.append(
|
1448 |
+
dict(
|
1449 |
+
input=df['input'].iloc[i],
|
1450 |
+
source=df['source'].iloc[i],
|
1451 |
+
prompt_type=df['prompt_type'].iloc[i],
|
1452 |
+
grade_deberta=df['grade_deberta'].iloc[i],
|
1453 |
+
)
|
1454 |
+
)
|
1455 |
+
if save_json:
|
1456 |
+
data_file = data_file + \
|
1457 |
+
("_h2ogpt" if fixup_personality else "") + \
|
1458 |
+
("_only" if only_personality else "") + \
|
1459 |
+
("_graded" if deberta_grading else "")
|
1460 |
+
for i in range(len(all_rows)):
|
1461 |
+
all_rows[i]['id'] = i
|
1462 |
+
with open(data_file.lower().replace("/", "_") + ".json", "w") as f:
|
1463 |
+
f.write(json.dumps(all_rows, indent=2))
|
1464 |
+
return all_rows
|
1465 |
+
|
1466 |
+
|
1467 |
+
def test_finalize_to_json():
|
1468 |
+
df = pd.read_parquet('h2oGPT.cleaned.graded.human_bot.shorter.parquet')
|
1469 |
+
df = df.rename(columns={'text': 'input'})
|
1470 |
+
|
1471 |
+
print("Number of high-quality human_bot interactions: %s" % df.shape[0], flush=True)
|
1472 |
+
|
1473 |
+
print("Adding open assistant data")
|
1474 |
+
with open("openassistant_oasst1_h2ogpt_graded.json") as f:
|
1475 |
+
open_assistant = json.loads(f.read())
|
1476 |
+
df = pd.concat([df, pd.DataFrame(open_assistant)], axis=0)
|
1477 |
+
|
1478 |
+
def final_clean(df):
|
1479 |
+
from better_profanity import profanity
|
1480 |
+
profanity.load_censor_words_from_file("data/censor_words.txt")
|
1481 |
+
df['profanity'] = parallel_apply(
|
1482 |
+
df['input'],
|
1483 |
+
lambda x: profanity.contains_profanity(x),
|
1484 |
+
n_jobs=-1,
|
1485 |
+
)
|
1486 |
+
return df[(df['profanity'] == 0)].reset_index(drop=True)
|
1487 |
+
print("Before cleaning: Number of final high-quality human_bot interactions: %s" % df.shape[0], flush=True)
|
1488 |
+
df = final_clean(df)
|
1489 |
+
print("After cleaning: Number of final high-quality human_bot interactions: %s" % df.shape[0], flush=True)
|
1490 |
+
print(df.describe())
|
1491 |
+
print(df.shape)
|
1492 |
+
row_list = []
|
1493 |
+
for i in range(df.shape[0]):
|
1494 |
+
row_list.append(
|
1495 |
+
dict(
|
1496 |
+
input=df.loc[i, 'input'],
|
1497 |
+
source=df.loc[i, 'source'],
|
1498 |
+
prompt_type='plain',
|
1499 |
+
)
|
1500 |
+
)
|
1501 |
+
np.random.seed(1234)
|
1502 |
+
np.random.shuffle(row_list)
|
1503 |
+
unhelpful = get_unhelpful_list()
|
1504 |
+
row_list = [x for x in row_list if not any(u in x['input'] for u in unhelpful)]
|
1505 |
+
for i in range(len(row_list)):
|
1506 |
+
row_list[i]['id'] = i
|
1507 |
+
row_list[i]['input'] = row_list[i]['input'].replace(" <bot>:", "\n<bot>:")
|
1508 |
+
with open('h2ogpt-oig-oasst1-instruct-cleaned-v3.json', "w") as f:
|
1509 |
+
f.write(json.dumps(row_list, indent=2))
|
1510 |
+
|
1511 |
+
|
1512 |
+
def create_personality_data():
|
1513 |
+
questions = [
|
1514 |
+
"What's your name?",
|
1515 |
+
"What is your name?",
|
1516 |
+
"What are you?",
|
1517 |
+
"Who are you?",
|
1518 |
+
"Do you have a name?",
|
1519 |
+
"Who trained you?",
|
1520 |
+
"Who created you?",
|
1521 |
+
"Who made you?",
|
1522 |
+
]
|
1523 |
+
answers = [
|
1524 |
+
"I'm h2oGPT, a large language model by H2O.ai.",
|
1525 |
+
"I'm h2oGPT, a large language model by H2O.ai, the visionary leader in democratizing AI.",
|
1526 |
+
"My name is h2oGPT. I'm a large language model by H2O.ai, the visionary leader in democratizing AI.",
|
1527 |
+
"My name is h2oGPT. I'm a large language model trained by H2O.ai.",
|
1528 |
+
"Hi! I'm h2oGPT, a large language model by H2O.ai.",
|
1529 |
+
"Hi! I'm h2oGPT, a large language model by H2O.ai, the visionary leader in democratizing AI.",
|
1530 |
+
]
|
1531 |
+
help = [
|
1532 |
+
"",
|
1533 |
+
" How can I help you?",
|
1534 |
+
" How may I assist you?",
|
1535 |
+
" Nice to meet you.",
|
1536 |
+
]
|
1537 |
+
import itertools
|
1538 |
+
rows = []
|
1539 |
+
for pair in itertools.product(questions, answers, help):
|
1540 |
+
rows.append(
|
1541 |
+
dict(input=f"<human>: {pair[0]}\n<bot>: {pair[1]}{pair[2]}\n<human>:", prompt_type='plain', source="H2O.ai")
|
1542 |
+
)
|
1543 |
+
for row in [
|
1544 |
+
"<human>: What is H2O.ai?\n<bot>: H2O.ai is a technology company that aims to democratize AI and make it accessible to a broader audience by simplifying the process of creating and deploying machine learning models.\n<human>:",
|
1545 |
+
"<human>: What is h2o.ai?\n<bot>: H2O.ai is a technology company that aims to democratize AI and make it accessible to a broader audience by simplifying the process of creating and deploying machine learning models.\n<human>:",
|
1546 |
+
"<human>: What is H2O?\n<bot>: H2O.ai is a technology company that aims to democratize AI and make it accessible to a broader audience by simplifying the process of creating and deploying machine learning models.\n<human>:",
|
1547 |
+
"<human>: Who is h2o.ai?\n<bot>: H2O.ai is a technology company that aims to democratize AI and make it accessible to a broader audience by simplifying the process of creating and deploying machine learning models.\n<human>:",
|
1548 |
+
"<human>: who is h2o.ai?\n<bot>: H2O.ai is a technology company that aims to democratize AI and make it accessible to a broader audience by simplifying the process of creating and deploying machine learning models.\n<human>:",
|
1549 |
+
"<human>: who is h2o?\n<bot>: H2O.ai is a technology company that aims to democratize AI and make it accessible to a broader audience by simplifying the process of creating and deploying machine learning models.\n<human>:",
|
1550 |
+
"<human>: What is H2O.ai?\n<bot>: H2O.ai is the visionary leader in democratizing AI.\n<human>:",
|
1551 |
+
"<human>: Who is H2O.ai?\n<bot>: H2O.ai is the visionary leader in democratizing AI.\n<human>:",
|
1552 |
+
"<human>: Who is H2O?\n<bot>: H2O.ai is the visionary leader in democratizing AI.\n<human>:",
|
1553 |
+
"<human>: Who is h2o?\n<bot>: H2O.ai is the visionary leader in democratizing AI.\n<human>:",
|
1554 |
+
"<human>: who is h2o?\n<bot>: H2O.ai is the visionary leader in democratizing AI.\n<human>:",
|
1555 |
+
]:
|
1556 |
+
rows.append(dict(input=row, prompt_type='plain', source='H2O.ai'))
|
1557 |
+
print(len(rows))
|
1558 |
+
with open("h2ogpt-personality.json", "w") as f:
|
1559 |
+
f.write(json.dumps(rows, indent=2))
|
1560 |
+
return rows
|
1561 |
+
|
1562 |
+
|
1563 |
+
def test_check_stats_data():
|
1564 |
+
filename = 'h2ogpt-oig-oasst1-instruct-cleaned-v3.json'
|
1565 |
+
df = pd.read_json(filename)
|
1566 |
+
|
1567 |
+
# get word stats
|
1568 |
+
df['char_count'] = df['input'].apply(lambda x: len(x))
|
1569 |
+
import matplotlib.pyplot as plt
|
1570 |
+
plt.figure(figsize=(10, 10))
|
1571 |
+
plt.hist(df['char_count'], bins=100)
|
1572 |
+
chars_avg = np.mean(df['char_count'])
|
1573 |
+
chars_median = np.median(df['char_count'])
|
1574 |
+
plt.title("char_count avg: %s median: %s" % (chars_avg, chars_median))
|
1575 |
+
plt.savefig('chars_hist.png')
|
1576 |
+
plt.close()
|
1577 |
+
|
1578 |
+
# get tokenize stats for random sample of 1000 rows
|
1579 |
+
from finetune import generate_and_tokenize_prompt
|
1580 |
+
from loaders import get_loaders, get_tokenizer
|
1581 |
+
from functools import partial
|
1582 |
+
|
1583 |
+
llama_type = False
|
1584 |
+
tokenizer_base_model = base_model = 'h2oai/h2ogpt-oasst1-512-20b'
|
1585 |
+
model_loader, tokenizer_loader = get_loaders(llama_type=llama_type, model_name=base_model, reward_type=False)
|
1586 |
+
local_files_only = False
|
1587 |
+
resume_download = True
|
1588 |
+
use_auth_token = False
|
1589 |
+
tokenizer = get_tokenizer(tokenizer_loader, tokenizer_base_model, local_files_only, resume_download, use_auth_token)
|
1590 |
+
prompt_type = 'plain' # trained with data already in human bot form
|
1591 |
+
train_on_inputs = True
|
1592 |
+
add_eos_token = False
|
1593 |
+
cutoff_len = 512 # can choose 2048
|
1594 |
+
generate_and_tokenize_prompt_fun = partial(generate_and_tokenize_prompt, prompt_type=prompt_type,
|
1595 |
+
train_on_inputs=train_on_inputs, add_eos_token=add_eos_token,
|
1596 |
+
cutoff_len=cutoff_len, tokenizer=tokenizer)
|
1597 |
+
from datasets import load_dataset
|
1598 |
+
data = load_dataset("json", data_files={"train": filename})
|
1599 |
+
val_set_size = 0.90
|
1600 |
+
train_val = data["train"].train_test_split(
|
1601 |
+
test_size=val_set_size, shuffle=True, seed=42
|
1602 |
+
)
|
1603 |
+
train_data = train_val["train"]
|
1604 |
+
train_data = train_data.shuffle().map(generate_and_tokenize_prompt_fun, num_proc=os.cpu_count())
|
1605 |
+
|
1606 |
+
df_tokens = pd.DataFrame([len(x) for x in train_data['input_ids']], columns=['token_count'])
|
1607 |
+
|
1608 |
+
plt.figure(figsize=(10, 10))
|
1609 |
+
plt.hist(df_tokens['token_count'], bins=100)
|
1610 |
+
token_avg = np.mean(df_tokens['token_count'])
|
1611 |
+
token_median = np.median(df_tokens['token_count'])
|
1612 |
+
plt.title("token_count with cutoff=%s avg: %s median: %s" % (cutoff_len, token_avg, token_median))
|
1613 |
+
plt.savefig('token_hist_%s.png' % cutoff_len)
|
1614 |
+
plt.close()
|
1615 |
+
|
1616 |
+
|
1617 |
+
def get_unhelpful_list():
|
1618 |
+
# base versions
|
1619 |
+
unhelpful = ["I'm sorry, I didn't quite understand your question, could you please rephrase it?",
|
1620 |
+
"I'm sorry, but I don't understand your question. Could you please rephrase it?",
|
1621 |
+
"I'm sorry, I don't quite understand your question",
|
1622 |
+
"I'm sorry, I don't know",
|
1623 |
+
"I'm sorry, but I don't know",
|
1624 |
+
"I don't know anything",
|
1625 |
+
"I do not know",
|
1626 |
+
"I don't know",
|
1627 |
+
"I don't know how",
|
1628 |
+
"I do not know how",
|
1629 |
+
"Can you please explain what you mean",
|
1630 |
+
"please explain what you mean",
|
1631 |
+
"please explain",
|
1632 |
+
"I'm sorry, but I don't know how to tell a story. Can you please explain what you mean by",
|
1633 |
+
"I'm sorry but I don't understand what you mean",
|
1634 |
+
"I don't understand",
|
1635 |
+
"I don't have the ability",
|
1636 |
+
"I do not have the ability",
|
1637 |
+
"I do not have",
|
1638 |
+
"I am a language model,",
|
1639 |
+
"I am a large language model,",
|
1640 |
+
"I do not understand your question. Can you please try to make it clearer?",
|
1641 |
+
"I'm sorry, but as an AI language model",
|
1642 |
+
"I apologize, but I cannot rephrase text that I cannot understand. Your post is difficult to read and follow.",
|
1643 |
+
"I apologize, but I am not h2oGPT. I am a language model developed by H2O.ai. How may I help you?",
|
1644 |
+
"Sorry, but I am not an actual Linux shell, nor am I capable of emulating one. I am an open source chat assistant and would be glad t",
|
1645 |
+
"I apologize, but I cannot perform the task you have requested.",
|
1646 |
+
"I'm sorry, I cannot perform this task as I am an AI language model and do not have access",
|
1647 |
+
"I'm sorry, I'm not sure what you're asking for here.",
|
1648 |
+
"I'm not sure what you are asking",
|
1649 |
+
"You need to provide more context",
|
1650 |
+
]
|
1651 |
+
# reduced versions, with redundant parts, just to give context for where they came from
|
1652 |
+
unhelpful += ["sorry, I didn't quite understand your question",
|
1653 |
+
"I didn't quite understand your question",
|
1654 |
+
"I didn't understand your question",
|
1655 |
+
"I did not understand your question",
|
1656 |
+
"I did not understand the question",
|
1657 |
+
"could you please rephrase"
|
1658 |
+
"could you rephrase"
|
1659 |
+
"I do not understand your question.",
|
1660 |
+
"I do not understand the question.",
|
1661 |
+
"I do not understand that question.",
|
1662 |
+
"Can you please try to make it clearer",
|
1663 |
+
"Can you try to make it clearer",
|
1664 |
+
"sorry, but as an AI language model",
|
1665 |
+
"as an AI language model",
|
1666 |
+
"I apologize, but I cannot",
|
1667 |
+
"I cannot rephrase text",
|
1668 |
+
"I cannot understand. Your post is difficult to read and follow."
|
1669 |
+
"Your post is difficult to read and follow."
|
1670 |
+
"I apologize, but I am",
|
1671 |
+
"Sorry, but I am not ",
|
1672 |
+
"nor am I capable",
|
1673 |
+
"I am not capable of",
|
1674 |
+
"I apologize, but I cannot perform the task you have requested",
|
1675 |
+
"I cannot perform the task",
|
1676 |
+
"I cannot complete the task",
|
1677 |
+
"I'm sorry",
|
1678 |
+
"I am sorry",
|
1679 |
+
"do not have access",
|
1680 |
+
"not sure what you're asking for",
|
1681 |
+
"not sure what you are asking for",
|
1682 |
+
"not sure what is being asked",
|
1683 |
+
"I'm not sure what you are asking",
|
1684 |
+
"not sure what you are asking",
|
1685 |
+
"You need to provide more context",
|
1686 |
+
"provide more context",
|
1687 |
+
]
|
1688 |
+
unhelpful += ["As a large language model",
|
1689 |
+
"cannot provide any information",
|
1690 |
+
"As an artificial intelligence I do not have the capability",
|
1691 |
+
"As an artificial intelligence I don't have the capability",
|
1692 |
+
"As an artificial intelligence I can't",
|
1693 |
+
"As an artificial intelligence I cannot",
|
1694 |
+
"I am sorry but I do not understand",
|
1695 |
+
"Can you please explain",
|
1696 |
+
"(sorry couldn't resist)",
|
1697 |
+
"(sorry could not resist)",
|
1698 |
+
" :)",
|
1699 |
+
" ;)",
|
1700 |
+
" :-)",
|
1701 |
+
" ;-)",
|
1702 |
+
" lol ",
|
1703 |
+
"Thanks so much!!!",
|
1704 |
+
"Thank You :)!!!",
|
1705 |
+
"Please try not to repeat",
|
1706 |
+
"I am an AI language model",
|
1707 |
+
"I'm a AI assistant that",
|
1708 |
+
"I'm an AI assistant that",
|
1709 |
+
"I am an AI assistant that",
|
1710 |
+
"etc.",
|
1711 |
+
"etc.etc.",
|
1712 |
+
"etc. etc.",
|
1713 |
+
"etc etc",
|
1714 |
+
]
|
1715 |
+
return unhelpful
|
1716 |
+
|
1717 |
+
|
1718 |
+
def test_check_unhelpful():
|
1719 |
+
# file = '/home/jon/Downloads/openassistant_oasst1_h2ogpt_graded.json'
|
1720 |
+
file = '/home/jon/Downloads/openassistant_oasst1_h2ogpt_grades.json'
|
1721 |
+
# file = 'h2ogpt-oig-oasst1-instruct-cleaned-v2.json'
|
1722 |
+
|
1723 |
+
unhelpful = get_unhelpful_list()
|
1724 |
+
#data = json.load(open(file, 'rt'))
|
1725 |
+
df = pd.read_json(file)
|
1726 |
+
|
1727 |
+
use_reward_score_threshold = False
|
1728 |
+
use_bleu_threshold = False
|
1729 |
+
use_sentence_sim = True
|
1730 |
+
|
1731 |
+
from sacrebleu.metrics import BLEU
|
1732 |
+
bleu = BLEU()
|
1733 |
+
from nltk.translate.bleu_score import sentence_bleu
|
1734 |
+
|
1735 |
+
def get_bleu(actual, expected_list):
|
1736 |
+
#return bleu.sentence_score(actual, expected_list).score
|
1737 |
+
return sentence_bleu(expected_list, actual)
|
1738 |
+
|
1739 |
+
threshold = 0.0
|
1740 |
+
if use_reward_score_threshold:
|
1741 |
+
df = df[df['grade_deberta'] > threshold]
|
1742 |
+
|
1743 |
+
# back to as if original json load
|
1744 |
+
data = df.to_dict(orient='records')
|
1745 |
+
bads = {}
|
1746 |
+
string_all = str(data)
|
1747 |
+
for sub in unhelpful:
|
1748 |
+
bads[sub] = string_all.count(sub)
|
1749 |
+
bads = {k: v for k, v in bads.items() if v > 0}
|
1750 |
+
import pprint
|
1751 |
+
pp = pprint.PrettyPrinter(indent=4)
|
1752 |
+
pp.pprint(bads)
|
1753 |
+
|
1754 |
+
total_bads = sum(list(bads.values()))
|
1755 |
+
print('total_bads: %s' % total_bads, flush=True)
|
1756 |
+
|
1757 |
+
# check just bot
|
1758 |
+
import re
|
1759 |
+
convs = [[x.strip() for x in re.split(r'%s|%s' % (human, bot), y['input']) if x.strip()] for y in data]
|
1760 |
+
humans = [[x for i, x in enumerate(y) if i % 2 == 0] for y in convs]
|
1761 |
+
bots = [[x for i, x in enumerate(y) if i % 2 == 1] for y in convs]
|
1762 |
+
|
1763 |
+
# FIXME: apply back to json etc., just see for now
|
1764 |
+
bleu_threshold = 0.9
|
1765 |
+
if use_bleu_threshold:
|
1766 |
+
bots = [[x for x in y if get_bleu(x, unhelpful) < bleu_threshold] for y in tqdm(bots)]
|
1767 |
+
|
1768 |
+
cosine_sim_threshold = 0.8
|
1769 |
+
if use_sentence_sim:
|
1770 |
+
# pip install sentence_transformers-2.2.2
|
1771 |
+
from sentence_transformers import SentenceTransformer
|
1772 |
+
# sent_model = 'bert-base-nli-mean-tokens'
|
1773 |
+
#sent_model = 'nli-distilroberta-base-v2'
|
1774 |
+
sent_model = 'all-MiniLM-L6-v2'
|
1775 |
+
model = SentenceTransformer(sent_model)
|
1776 |
+
sentence_embeddings = model.encode(unhelpful)
|
1777 |
+
from sklearn.metrics.pairwise import cosine_similarity
|
1778 |
+
bots = [x for x in tqdm(bots) if np.max(cosine_similarity(model.encode(x), sentence_embeddings)) < cosine_sim_threshold]
|
1779 |
+
|
1780 |
+
bads_bots = {}
|
1781 |
+
string_all = str(bots)
|
1782 |
+
for sub in unhelpful:
|
1783 |
+
bads_bots[sub] = string_all.count(sub)
|
1784 |
+
bads_bots = {k: v for k, v in bads_bots.items() if v > 0}
|
1785 |
+
import pprint
|
1786 |
+
pp = pprint.PrettyPrinter(indent=4)
|
1787 |
+
pp.pprint(bads_bots)
|
1788 |
+
|
1789 |
+
total_bads_bots = sum(list(bads_bots.values()))
|
1790 |
+
print('threshold: %g use_bleu_threshold: %g total_bads_bots: %s total_bots: %s total_humans: %s' % (threshold, use_bleu_threshold, total_bads_bots, len(bots), len(humans)), flush=True)
|
1791 |
+
|
1792 |
+
# assert len(bads) == 0, bads
|
1793 |
+
assert len(bads_bots) == 0, bads_bots
|
1794 |
+
|
1795 |
+
|
1796 |
+
def test_fortune2000_personalized():
|
1797 |
+
row_list = []
|
1798 |
+
import glob
|
1799 |
+
if not os.path.isdir("wikitext"):
|
1800 |
+
raise RuntimeError("download https://github.com/h2oai/h2ogpt/files/11423008/wikitext.zip and unzip")
|
1801 |
+
for file in glob.glob("wikitext/*.txt"):
|
1802 |
+
with open(file, "r") as f:
|
1803 |
+
blob = f.read()
|
1804 |
+
N = 512 * 4
|
1805 |
+
row_list.extend([{'input': s, 'prompt_type': 'plain', 'source': "%s" % os.path.basename(file)}
|
1806 |
+
for s in get_sentences(blob, N) if s])
|
1807 |
+
personality = create_personality_data()
|
1808 |
+
import copy
|
1809 |
+
for i in range(10):
|
1810 |
+
row_list.extend(copy.deepcopy(personality))
|
1811 |
+
np.random.seed(123)
|
1812 |
+
np.random.shuffle(row_list)
|
1813 |
+
for i in range(len(row_list)):
|
1814 |
+
row_list[i]['id'] = i
|
1815 |
+
for i in range(len(row_list)):
|
1816 |
+
assert row_list[i]['id'] == i
|
1817 |
+
with open("h2ogpt-fortune2000-personalized.json", "w") as ff:
|
1818 |
+
ff.write(json.dumps(row_list, indent=2))
|
finetune.py
CHANGED
@@ -1,11 +1,12 @@
|
|
1 |
import os
|
2 |
import sys
|
3 |
-
import time
|
4 |
from functools import partial
|
5 |
from typing import List, Union
|
6 |
-
from enum import Enum
|
7 |
import fire
|
8 |
import numpy as np
|
|
|
|
|
|
|
9 |
from utils import get_githash, copy_code
|
10 |
import torch
|
11 |
|
@@ -17,82 +18,6 @@ def log(*args, **kwargs):
|
|
17 |
print(*args, **kwargs)
|
18 |
|
19 |
|
20 |
-
class PromptType(Enum):
|
21 |
-
plain = 0
|
22 |
-
instruct = 1
|
23 |
-
quality = 2
|
24 |
-
human_bot = 3
|
25 |
-
dai_faq = 4
|
26 |
-
summarize = 5
|
27 |
-
simple_instruct = 6
|
28 |
-
instruct_vicuna = 7
|
29 |
-
instruct_with_end = 8
|
30 |
-
human_bot_orig = 9
|
31 |
-
prompt_answer = 10
|
32 |
-
open_assistant = 11
|
33 |
-
wizard_lm = 12
|
34 |
-
|
35 |
-
|
36 |
-
prompt_type_to_model_name = {
|
37 |
-
'plain': [
|
38 |
-
'EleutherAI/gpt-j-6B',
|
39 |
-
'EleutherAI/pythia-6.9b',
|
40 |
-
'EleutherAI/pythia-12b',
|
41 |
-
'EleutherAI/pythia-12b-deduped',
|
42 |
-
'EleutherAI/gpt-neox-20b',
|
43 |
-
'decapoda-research/llama-7b-hf',
|
44 |
-
'decapoda-research/llama-13b-hf',
|
45 |
-
'decapoda-research/llama-30b-hf',
|
46 |
-
'decapoda-research/llama-65b-hf',
|
47 |
-
'facebook/mbart-large-50-many-to-many-mmt',
|
48 |
-
'philschmid/bart-large-cnn-samsum',
|
49 |
-
'philschmid/flan-t5-base-samsum',
|
50 |
-
'gpt2',
|
51 |
-
'distilgpt2',
|
52 |
-
'mosaicml/mpt-7b-storywriter',
|
53 |
-
'mosaicml/mpt-7b-instruct', # internal code handles instruct
|
54 |
-
'mosaicml/mpt-7b-chat', # NC, internal code handles instruct
|
55 |
-
],
|
56 |
-
'prompt_answer': [
|
57 |
-
'h2oai/h2ogpt-gm-oasst1-en-1024-20b',
|
58 |
-
'h2oai/h2ogpt-gm-oasst1-en-1024-12b',
|
59 |
-
'h2oai/h2ogpt-gm-oasst1-multilang-1024-20b',
|
60 |
-
'h2oai/h2ogpt-gm-oasst1-en-2048-open-llama-7b-preview-300bt',
|
61 |
-
'h2oai/h2ogpt-gm-oasst1-en-2048-open-llama-7b-preview-300bt-v2',
|
62 |
-
],
|
63 |
-
'instruct': [],
|
64 |
-
'instruct_with_end': ['databricks/dolly-v2-12b'],
|
65 |
-
'quality': [],
|
66 |
-
'human_bot': [
|
67 |
-
'h2oai/h2ogpt-oasst1-512-12b',
|
68 |
-
'h2oai/h2ogpt-oasst1-512-20b',
|
69 |
-
'h2oai/h2ogpt-oig-oasst1-512-20b',
|
70 |
-
'h2oai/h2ogpt-oig-oasst1-512-12b',
|
71 |
-
'h2oai/h2ogpt-oig-oasst1-512-6.9b',
|
72 |
-
'h2oai/h2ogpt-research-oasst1-512-30b', # private
|
73 |
-
],
|
74 |
-
'dai_faq': [],
|
75 |
-
'summarize': [],
|
76 |
-
'simple_instruct': ['t5-small', 't5-large', 'google/flan-t5', 'google/flan-t5-xxl', 'google/flan-ul2'],
|
77 |
-
'instruct_vicuna': ['AlekseyKorshuk/vicuna-7b', 'TheBloke/stable-vicuna-13B-HF', 'junelee/wizard-vicuna-13b'],
|
78 |
-
'human_bot_orig': ['togethercomputer/GPT-NeoXT-Chat-Base-20B'],
|
79 |
-
"open_assistant": ['OpenAssistant/oasst-sft-7-llama-30b-xor', 'oasst-sft-7-llama-30b'],
|
80 |
-
"wizard_lm": ['ehartford/WizardLM-7B-Uncensored', 'ehartford/WizardLM-13B-Uncensored'],
|
81 |
-
}
|
82 |
-
|
83 |
-
inv_prompt_type_to_model_name = {v.strip(): k for k, l in prompt_type_to_model_name.items() for v in l}
|
84 |
-
inv_prompt_type_to_model_lower = {v.strip().lower(): k for k, l in prompt_type_to_model_name.items() for v in l}
|
85 |
-
|
86 |
-
prompt_types_strings = []
|
87 |
-
for p in PromptType:
|
88 |
-
prompt_types_strings.extend([p.name])
|
89 |
-
|
90 |
-
|
91 |
-
prompt_types = []
|
92 |
-
for p in PromptType:
|
93 |
-
prompt_types.extend([p.name, p.value, str(p.value)])
|
94 |
-
|
95 |
-
|
96 |
# supported by huggingface evaluate
|
97 |
supported_metrics = ['bleu', 'rouge', 'sacrebleu', 'meteor']
|
98 |
|
@@ -353,7 +278,7 @@ def train(
|
|
353 |
if os.path.exists(checkpoint_name):
|
354 |
log(f"Restarting from {checkpoint_name}")
|
355 |
adapters_weights = torch.load(checkpoint_name)
|
356 |
-
|
357 |
else:
|
358 |
log(f"Checkpoint {checkpoint_name} not found")
|
359 |
|
@@ -656,58 +581,6 @@ def train(
|
|
656 |
log("\n If there's a warning about missing keys above, please disregard :)")
|
657 |
|
658 |
|
659 |
-
def get_loaders(llama_type, model_name, reward_type):
|
660 |
-
# NOTE: Some models need specific new prompt_type
|
661 |
-
# E.g. t5_xxl_true_nli_mixture has input format: "premise: PREMISE_TEXT hypothesis: HYPOTHESIS_TEXT".)
|
662 |
-
if llama_type:
|
663 |
-
from transformers import LlamaForCausalLM, LlamaTokenizer
|
664 |
-
model_loader = LlamaForCausalLM
|
665 |
-
tokenizer_loader = LlamaTokenizer
|
666 |
-
elif 'distilgpt2' in model_name.lower():
|
667 |
-
from transformers import AutoModelForCausalLM, AutoTokenizer
|
668 |
-
return AutoModelForCausalLM, AutoTokenizer
|
669 |
-
elif 'gpt2' in model_name.lower():
|
670 |
-
from transformers import GPT2LMHeadModel, GPT2Tokenizer
|
671 |
-
return GPT2LMHeadModel, GPT2Tokenizer
|
672 |
-
elif 'mbart-' in model_name.lower():
|
673 |
-
from transformers import MBartForConditionalGeneration, MBart50TokenizerFast
|
674 |
-
return MBartForConditionalGeneration, MBart50TokenizerFast
|
675 |
-
elif 't5' == model_name.lower() or \
|
676 |
-
't5-' in model_name.lower() or \
|
677 |
-
'flan-' in model_name.lower():
|
678 |
-
from transformers import AutoTokenizer, T5ForConditionalGeneration
|
679 |
-
return T5ForConditionalGeneration, AutoTokenizer
|
680 |
-
elif 'bigbird' in model_name:
|
681 |
-
from transformers import BigBirdPegasusForConditionalGeneration, AutoTokenizer
|
682 |
-
return BigBirdPegasusForConditionalGeneration, AutoTokenizer
|
683 |
-
elif 'bart-large-cnn-samsum' in model_name or 'flan-t5-base-samsum' in model_name:
|
684 |
-
from transformers import pipeline
|
685 |
-
return pipeline, "summarization"
|
686 |
-
elif reward_type or 'OpenAssistant/reward-model'.lower() in model_name.lower():
|
687 |
-
from transformers import AutoModelForSequenceClassification, AutoTokenizer
|
688 |
-
return AutoModelForSequenceClassification, AutoTokenizer
|
689 |
-
else:
|
690 |
-
from transformers import AutoTokenizer, AutoModelForCausalLM
|
691 |
-
model_loader = AutoModelForCausalLM
|
692 |
-
tokenizer_loader = AutoTokenizer
|
693 |
-
return model_loader, tokenizer_loader
|
694 |
-
|
695 |
-
|
696 |
-
def get_tokenizer(tokenizer_loader, tokenizer_base_model, local_files_only, resume_download, use_auth_token):
|
697 |
-
tokenizer = tokenizer_loader.from_pretrained(tokenizer_base_model,
|
698 |
-
local_files_only=local_files_only,
|
699 |
-
resume_download=resume_download,
|
700 |
-
use_auth_token=use_auth_token)
|
701 |
-
|
702 |
-
tokenizer.pad_token_id = 0 # different from the eos token
|
703 |
-
# when generating, we will use the logits of right-most token to predict the next token
|
704 |
-
# so the padding should be on the left,
|
705 |
-
# e.g. see: https://huggingface.co/transformers/v4.11.3/model_doc/t5.html#inference
|
706 |
-
tokenizer.padding_side = "left" # Allow batched inference
|
707 |
-
|
708 |
-
return tokenizer
|
709 |
-
|
710 |
-
|
711 |
def tokenize(prompt, tokenizer, cutoff_len, add_eos_token=False):
|
712 |
# there's probably a way to do this with the tokenizer settings
|
713 |
# but again, gotta move fast
|
@@ -765,253 +638,6 @@ def generate_and_tokenize_prompt(data_point, prompt_type=None, train_on_inputs=F
|
|
765 |
return tokenized_full_prompt
|
766 |
|
767 |
|
768 |
-
def get_prompt(prompt_type, chat, context, reduced):
|
769 |
-
if prompt_type in [-1, "-1", "plain"]:
|
770 |
-
promptA = promptB = PreInstruct = PreInput = PreResponse = ''
|
771 |
-
terminate_response = []
|
772 |
-
chat_sep = ''
|
773 |
-
elif prompt_type == 'simple_instruct':
|
774 |
-
promptA = promptB = PreInstruct = PreInput = PreResponse = None
|
775 |
-
terminate_response = []
|
776 |
-
chat_sep = '\n'
|
777 |
-
elif prompt_type in [0, "0", "instruct"] or prompt_type in [7, "7", "instruct_with_end"]:
|
778 |
-
promptA = 'Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.\n' if not (chat and reduced) else ''
|
779 |
-
promptB = 'Below is an instruction that describes a task. Write a response that appropriately completes the request.\n' if not (chat and reduced) else ''
|
780 |
-
|
781 |
-
PreInstruct = """
|
782 |
-
### Instruction:
|
783 |
-
"""
|
784 |
-
|
785 |
-
PreInput = """
|
786 |
-
### Input:
|
787 |
-
"""
|
788 |
-
|
789 |
-
PreResponse = """
|
790 |
-
### Response:
|
791 |
-
"""
|
792 |
-
if prompt_type in [7, "7", "instruct_with_end"]:
|
793 |
-
terminate_response = ['### End']
|
794 |
-
else:
|
795 |
-
terminate_response = None
|
796 |
-
chat_sep = '\n'
|
797 |
-
elif prompt_type in [1, "1", "quality"]:
|
798 |
-
promptA = 'Write a detailed high-quality, accurate, fair, Response with about 100 words by following the Instruction as applied on the Input.\n' if not (chat and reduced) else ''
|
799 |
-
promptB = 'Write a detailed high-quality, accurate, fair, Response with about 100 words by following the Instruction.\n' if not (chat and reduced) else ''
|
800 |
-
|
801 |
-
PreInstruct = """
|
802 |
-
### Instruction:
|
803 |
-
"""
|
804 |
-
|
805 |
-
PreInput = """
|
806 |
-
### Input:
|
807 |
-
"""
|
808 |
-
|
809 |
-
PreResponse = """
|
810 |
-
### Response:
|
811 |
-
"""
|
812 |
-
terminate_response = None
|
813 |
-
chat_sep = '\n'
|
814 |
-
elif prompt_type in [2, "2", "human_bot", 9, "9", "human_bot_orig"]:
|
815 |
-
human = '<human>:'
|
816 |
-
bot = "<bot>:"
|
817 |
-
if reduced or context or prompt_type in [2, "2", "human_bot"]:
|
818 |
-
preprompt = ''
|
819 |
-
else:
|
820 |
-
cur_date = time.strftime('%Y-%m-%d')
|
821 |
-
cur_time = time.strftime('%H:%M:%S %p %Z')
|
822 |
-
|
823 |
-
PRE_PROMPT = """\
|
824 |
-
Current Date: {}
|
825 |
-
Current Time: {}
|
826 |
-
|
827 |
-
"""
|
828 |
-
preprompt = PRE_PROMPT.format(cur_date, cur_time)
|
829 |
-
start = human
|
830 |
-
promptB = promptA = '%s%s ' % (preprompt, start)
|
831 |
-
|
832 |
-
PreInstruct = ""
|
833 |
-
|
834 |
-
PreInput = None
|
835 |
-
|
836 |
-
if reduced:
|
837 |
-
# when making context, want it to appear as-if LLM generated, which starts with space after :
|
838 |
-
PreResponse = bot + ' '
|
839 |
-
else:
|
840 |
-
# normally LLM adds space after this, because was how trained.
|
841 |
-
# if add space here, non-unique tokenization will often make LLM produce wrong output
|
842 |
-
PreResponse = bot
|
843 |
-
|
844 |
-
terminate_response = [start, PreResponse]
|
845 |
-
chat_sep = '\n'
|
846 |
-
elif prompt_type in [3, "3", "dai_faq"]:
|
847 |
-
promptA = ''
|
848 |
-
promptB = 'Answer the following Driverless AI question.\n'
|
849 |
-
|
850 |
-
PreInstruct = """
|
851 |
-
### Driverless AI frequently asked question:
|
852 |
-
"""
|
853 |
-
|
854 |
-
PreInput = None
|
855 |
-
|
856 |
-
PreResponse = """
|
857 |
-
### Driverless AI documentation answer:
|
858 |
-
"""
|
859 |
-
terminate_response = ['\n\n']
|
860 |
-
chat_sep = terminate_response
|
861 |
-
elif prompt_type in [5, "5", "summarize"]:
|
862 |
-
promptA = promptB = PreInput = ''
|
863 |
-
PreInstruct = '## Main Text\n\n'
|
864 |
-
PreResponse = '\n\n## Summary\n\n'
|
865 |
-
terminate_response = None
|
866 |
-
chat_sep = '\n'
|
867 |
-
elif prompt_type in [6, "6", "instruct_vicuna"]:
|
868 |
-
promptA = promptB = "A chat between a curious human and an artificial intelligence assistant. " \
|
869 |
-
"The assistant gives helpful, detailed, and polite answers to the human's questions." if not (chat and reduced) else ''
|
870 |
-
|
871 |
-
PreInstruct = """
|
872 |
-
### Human:
|
873 |
-
"""
|
874 |
-
|
875 |
-
PreInput = None
|
876 |
-
|
877 |
-
PreResponse = """
|
878 |
-
### Assistant:
|
879 |
-
"""
|
880 |
-
terminate_response = ['### Human:'] # but only allow terminate after prompt is found correctly, else can't terminate
|
881 |
-
chat_sep = '\n'
|
882 |
-
elif prompt_type in [10, "10", "prompt_answer"]:
|
883 |
-
preprompt = ''
|
884 |
-
prompt_tokens = "<|prompt|>"
|
885 |
-
answer_tokens = "<|answer|>"
|
886 |
-
start = prompt_tokens
|
887 |
-
promptB = promptA = '%s%s' % (preprompt, start)
|
888 |
-
PreInstruct = ""
|
889 |
-
PreInput = None
|
890 |
-
PreResponse = answer_tokens
|
891 |
-
eos = '<|endoftext|>' # neox eos
|
892 |
-
terminate_response = [start, PreResponse, eos]
|
893 |
-
chat_sep = eos
|
894 |
-
elif prompt_type in [11, "11", "open_assistant"]:
|
895 |
-
# From added_tokens.json
|
896 |
-
preprompt = ''
|
897 |
-
prompt_tokens = "<|prompter|>"
|
898 |
-
answer_tokens = "<|assistant|>"
|
899 |
-
start = prompt_tokens
|
900 |
-
promptB = promptA = '%s%s' % (preprompt, start)
|
901 |
-
PreInstruct = ""
|
902 |
-
PreInput = None
|
903 |
-
PreResponse = answer_tokens
|
904 |
-
pend = "<|prefix_end|>"
|
905 |
-
eos = "</s>"
|
906 |
-
terminate_response = [start, PreResponse, pend, eos]
|
907 |
-
chat_sep = eos
|
908 |
-
elif prompt_type in [12, "12", "wizard_lm"]:
|
909 |
-
# https://github.com/ehartford/WizardLM/blob/main/src/train_freeform.py
|
910 |
-
preprompt = ''
|
911 |
-
start = ''
|
912 |
-
promptB = promptA = '%s%s' % (preprompt, start)
|
913 |
-
PreInstruct = ""
|
914 |
-
PreInput = None
|
915 |
-
PreResponse = "\n\n### Response"
|
916 |
-
eos = "</s>"
|
917 |
-
terminate_response = [PreResponse, eos]
|
918 |
-
chat_sep = eos
|
919 |
-
else:
|
920 |
-
raise RuntimeError("No such prompt_type=%s" % prompt_type)
|
921 |
-
|
922 |
-
return promptA, promptB, PreInstruct, PreInput, PreResponse, terminate_response, chat_sep
|
923 |
-
|
924 |
-
|
925 |
-
def generate_prompt(data_point, prompt_type, chat, reduced):
|
926 |
-
context = data_point.get('context')
|
927 |
-
if context is None:
|
928 |
-
context = ''
|
929 |
-
instruction = data_point.get('instruction')
|
930 |
-
input = data_point.get('input')
|
931 |
-
output = data_point.get('output')
|
932 |
-
prompt_type = data_point.get('prompt_type', prompt_type)
|
933 |
-
assert prompt_type in prompt_types, "Bad prompt type: %s" % prompt_type
|
934 |
-
promptA, promptB, PreInstruct, PreInput, PreResponse, \
|
935 |
-
terminate_response, chat_sep = get_prompt(prompt_type, chat, context, reduced)
|
936 |
-
|
937 |
-
prompt = context if not reduced else ''
|
938 |
-
|
939 |
-
if input and promptA:
|
940 |
-
prompt += f"""{promptA}"""
|
941 |
-
elif promptB:
|
942 |
-
prompt += f"""{promptB}"""
|
943 |
-
|
944 |
-
if instruction and PreInstruct is not None and input and PreInput is not None:
|
945 |
-
prompt += f"""{PreInstruct}{instruction}{PreInput}{input}"""
|
946 |
-
prompt = inject_newline(prompt_type, prompt)
|
947 |
-
elif instruction and input and PreInstruct is None and PreInput is not None:
|
948 |
-
prompt += f"""{PreInput}{instruction}
|
949 |
-
{input}"""
|
950 |
-
prompt = inject_newline(prompt_type, prompt)
|
951 |
-
elif input and instruction and PreInput is None and PreInstruct is not None:
|
952 |
-
prompt += f"""{PreInstruct}{instruction}
|
953 |
-
{input}"""
|
954 |
-
prompt = inject_newline(prompt_type, prompt)
|
955 |
-
elif instruction and PreInstruct is not None:
|
956 |
-
prompt += f"""{PreInstruct}{instruction}"""
|
957 |
-
prompt = inject_newline(prompt_type, prompt)
|
958 |
-
elif input and PreInput is not None:
|
959 |
-
prompt += f"""{PreInput}{input}"""
|
960 |
-
prompt = inject_newline(prompt_type, prompt)
|
961 |
-
elif input and instruction and PreInput is not None:
|
962 |
-
prompt += f"""{PreInput}{instruction}{input}"""
|
963 |
-
prompt = inject_newline(prompt_type, prompt)
|
964 |
-
elif input and instruction and PreInstruct is not None:
|
965 |
-
prompt += f"""{PreInstruct}{instruction}{input}"""
|
966 |
-
prompt = inject_newline(prompt_type, prompt)
|
967 |
-
elif input and instruction:
|
968 |
-
# i.e. for simple_instruct
|
969 |
-
prompt += f"""{instruction}: {input}"""
|
970 |
-
prompt = inject_newline(prompt_type, prompt)
|
971 |
-
elif input:
|
972 |
-
prompt += f"""{input}"""
|
973 |
-
prompt = inject_newline(prompt_type, prompt)
|
974 |
-
elif instruction:
|
975 |
-
prompt += f"""{instruction}"""
|
976 |
-
prompt = inject_newline(prompt_type, prompt)
|
977 |
-
|
978 |
-
if PreResponse is not None:
|
979 |
-
prompt += f"""{PreResponse}"""
|
980 |
-
pre_response = PreResponse # Don't use strip
|
981 |
-
else:
|
982 |
-
pre_response = ''
|
983 |
-
|
984 |
-
if output:
|
985 |
-
prompt += f"""{output}"""
|
986 |
-
|
987 |
-
return prompt, pre_response, terminate_response, chat_sep
|
988 |
-
|
989 |
-
|
990 |
-
def inject_newline(prompt_type, prompt):
|
991 |
-
if prompt_type not in [-1, '-1', 'plain', 'simple_instruct']:
|
992 |
-
# only add new line if structured prompt, while 'plain' is just generation of next tokens from input
|
993 |
-
prompt += '\n'
|
994 |
-
return prompt
|
995 |
-
|
996 |
-
|
997 |
-
example_data_point0 = dict(instruction="Summarize",
|
998 |
-
input="Ducks eat seeds by the lake, then swim in the lake where fish eat small animals.",
|
999 |
-
output="Ducks eat and swim at the lake.")
|
1000 |
-
|
1001 |
-
example_data_point1 = dict(instruction="Who is smarter, Einstein or Newton?",
|
1002 |
-
output="Einstein.")
|
1003 |
-
|
1004 |
-
example_data_point2 = dict(input="Who is smarter, Einstein or Newton?",
|
1005 |
-
output="Einstein.")
|
1006 |
-
|
1007 |
-
example_data_points = [example_data_point0, example_data_point1, example_data_point2]
|
1008 |
-
|
1009 |
-
|
1010 |
-
def test_train_prompt(prompt_type='instruct', data_point=0):
|
1011 |
-
example_data_point = example_data_points[data_point]
|
1012 |
-
return generate_prompt(example_data_point, prompt_type, False, False)
|
1013 |
-
|
1014 |
-
|
1015 |
def test_debug():
|
1016 |
fire.Fire(train)
|
1017 |
|
|
|
1 |
import os
|
2 |
import sys
|
|
|
3 |
from functools import partial
|
4 |
from typing import List, Union
|
|
|
5 |
import fire
|
6 |
import numpy as np
|
7 |
+
|
8 |
+
from loaders import get_loaders, get_tokenizer
|
9 |
+
from prompter import generate_prompt, prompt_types
|
10 |
from utils import get_githash, copy_code
|
11 |
import torch
|
12 |
|
|
|
18 |
print(*args, **kwargs)
|
19 |
|
20 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
21 |
# supported by huggingface evaluate
|
22 |
supported_metrics = ['bleu', 'rouge', 'sacrebleu', 'meteor']
|
23 |
|
|
|
278 |
if os.path.exists(checkpoint_name):
|
279 |
log(f"Restarting from {checkpoint_name}")
|
280 |
adapters_weights = torch.load(checkpoint_name)
|
281 |
+
set_peft_model_state_dict(model, adapters_weights)
|
282 |
else:
|
283 |
log(f"Checkpoint {checkpoint_name} not found")
|
284 |
|
|
|
581 |
log("\n If there's a warning about missing keys above, please disregard :)")
|
582 |
|
583 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
584 |
def tokenize(prompt, tokenizer, cutoff_len, add_eos_token=False):
|
585 |
# there's probably a way to do this with the tokenizer settings
|
586 |
# but again, gotta move fast
|
|
|
638 |
return tokenized_full_prompt
|
639 |
|
640 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
641 |
def test_debug():
|
642 |
fire.Fire(train)
|
643 |
|
generate.py
CHANGED
@@ -1,5 +1,9 @@
|
|
|
|
1 |
import functools
|
|
|
|
|
2 |
import queue
|
|
|
3 |
import sys
|
4 |
import os
|
5 |
import time
|
@@ -9,7 +13,12 @@ from datetime import datetime
|
|
9 |
import filelock
|
10 |
import psutil
|
11 |
|
12 |
-
from
|
|
|
|
|
|
|
|
|
|
|
13 |
|
14 |
SEED = 1236
|
15 |
set_seed(SEED)
|
@@ -25,13 +34,16 @@ from peft import PeftModel
|
|
25 |
from transformers import GenerationConfig, AutoModel, TextIteratorStreamer
|
26 |
from accelerate import init_empty_weights, infer_auto_device_map
|
27 |
|
28 |
-
from prompter import Prompter
|
29 |
-
|
30 |
-
from finetune import get_loaders, example_data_points, generate_prompt, inv_prompt_type_to_model_lower
|
31 |
from stopping import get_stopping
|
32 |
|
33 |
eval_extra_columns = ['prompt', 'response', 'score']
|
34 |
|
|
|
|
|
|
|
|
|
|
|
35 |
|
36 |
def main(
|
37 |
load_8bit: bool = False,
|
@@ -63,6 +75,7 @@ def main(
|
|
63 |
resume_download: bool = True,
|
64 |
use_auth_token: Union[str, bool] = False,
|
65 |
trust_remote_code: Union[str, bool] = True,
|
|
|
66 |
|
67 |
src_lang: str = "English",
|
68 |
tgt_lang: str = "Russian",
|
@@ -70,7 +83,6 @@ def main(
|
|
70 |
gradio: bool = True,
|
71 |
gradio_avoid_processing_markdown: bool = False,
|
72 |
chat: bool = True,
|
73 |
-
chat_history: int = 4096,
|
74 |
chat_context: bool = False,
|
75 |
stream_output: bool = True,
|
76 |
show_examples: bool = None,
|
@@ -98,6 +110,30 @@ def main(
|
|
98 |
eval_sharegpt_prompts_only: int = 0,
|
99 |
eval_sharegpt_prompts_only_seed: int = 1234,
|
100 |
eval_sharegpt_as_output: bool = False,
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
101 |
):
|
102 |
"""
|
103 |
|
@@ -127,12 +163,12 @@ def main(
|
|
127 |
:param resume_download: whether to resume downloads from HF for models
|
128 |
:param use_auth_token: whether to use HF auth token (requires CLI did huggingface-cli login before)
|
129 |
:param trust_remote_code: whether to use trust any code needed for HF model
|
|
|
130 |
:param src_lang: source languages to include if doing translation (None = all)
|
131 |
:param tgt_lang: target languages to include if doing translation (None = all)
|
132 |
:param gradio: whether to enable gradio, or to enable benchmark mode
|
133 |
:param gradio_avoid_processing_markdown:
|
134 |
:param chat: whether to enable chat mode with chat history
|
135 |
-
:param chat_history: maximum character length of chat context/history
|
136 |
:param chat_context: whether to use extra helpful context if human_bot
|
137 |
:param stream_output: whether to stream output from generate
|
138 |
:param show_examples: whether to show clickable examples in gradio
|
@@ -157,6 +193,41 @@ def main(
|
|
157 |
:param eval_sharegpt_prompts_only: for no gradio benchmark, if using ShareGPT prompts for eval
|
158 |
:param eval_sharegpt_prompts_only_seed: for no gradio benchmark, if seed for ShareGPT sampling
|
159 |
:param eval_sharegpt_as_output: for no gradio benchmark, whether to test ShareGPT output itself
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
160 |
:return:
|
161 |
"""
|
162 |
is_hf = bool(os.getenv("HUGGINGFACE_SPACES"))
|
@@ -170,8 +241,20 @@ def main(
|
|
170 |
|
171 |
# allow set token directly
|
172 |
use_auth_token = os.environ.get("HUGGINGFACE_API_TOKEN", use_auth_token)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
173 |
|
174 |
if is_public:
|
|
|
175 |
input_lines = 1 # ensure set, for ease of use
|
176 |
temperature = 0.2 if temperature is None else temperature
|
177 |
top_p = 0.85 if top_p is None else top_p
|
@@ -211,7 +294,7 @@ def main(
|
|
211 |
torch.backends.cudnn.benchmark = True
|
212 |
torch.backends.cudnn.enabled = False
|
213 |
torch.set_default_dtype(torch.float32)
|
214 |
-
if psutil.virtual_memory().available < 94*1024**3:
|
215 |
# 12B uses ~94GB
|
216 |
# 6.9B uses ~47GB
|
217 |
base_model = 'h2oai/h2ogpt-oig-oasst1-512-6.9b' if not base_model else base_model
|
@@ -223,16 +306,22 @@ def main(
|
|
223 |
stream_output = False
|
224 |
# else prompt removal can mess up output
|
225 |
chat = False
|
|
|
|
|
|
|
|
|
|
|
|
|
226 |
|
227 |
placeholder_instruction, placeholder_input, \
|
228 |
-
|
229 |
-
|
230 |
-
|
231 |
-
|
232 |
-
|
233 |
-
|
234 |
-
|
235 |
-
|
236 |
get_generate_params(model_lower, chat,
|
237 |
stream_output, show_examples,
|
238 |
prompt_type, temperature, top_p, top_k, num_beams,
|
@@ -246,6 +335,38 @@ def main(
|
|
246 |
print(f"Generating model with params:\n{locals_print}", flush=True)
|
247 |
print("Command: %s\nHash: %s" % (str(' '.join(sys.argv)), get_githash()), flush=True)
|
248 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
249 |
if not gradio:
|
250 |
if eval_sharegpt_prompts_only > 0:
|
251 |
# override default examples with shareGPT ones for human-level eval purposes only
|
@@ -309,11 +430,9 @@ def main(
|
|
309 |
if not eval_sharegpt_as_output:
|
310 |
model, tokenizer, device = get_model(**locals())
|
311 |
model_state = [model, tokenizer, device, base_model]
|
312 |
-
|
313 |
-
|
314 |
-
|
315 |
-
concurrency_count=concurrency_count,
|
316 |
-
lora_weights=lora_weights)
|
317 |
else:
|
318 |
assert eval_sharegpt_prompts_only > 0
|
319 |
|
@@ -325,8 +444,6 @@ def main(
|
|
325 |
t0 = time.time()
|
326 |
score_dump = []
|
327 |
|
328 |
-
import matplotlib.pyplot as plt
|
329 |
-
|
330 |
for exi, ex in enumerate(examples):
|
331 |
instruction = ex[eval_func_param_names.index('instruction_nochat')]
|
332 |
iinput = ex[eval_func_param_names.index('iinput_nochat')]
|
@@ -363,7 +480,8 @@ def main(
|
|
363 |
try:
|
364 |
score = torch.sigmoid(smodel(**inputs).logits[0].float()).cpu().detach().numpy()[0]
|
365 |
except torch.cuda.OutOfMemoryError as e:
|
366 |
-
print("GPU OOM 1: question: %s answer: %s exception: %s" % (prompt, res, str(e)),
|
|
|
367 |
traceback.print_exc()
|
368 |
score = 0.0
|
369 |
clear_torch_cache()
|
@@ -419,22 +537,23 @@ def main(
|
|
419 |
smodel, stokenizer, sdevice = get_score_model(**all_kwargs)
|
420 |
score_model_state0 = [smodel, stokenizer, sdevice, score_model]
|
421 |
|
422 |
-
|
423 |
-
|
424 |
-
|
425 |
-
|
426 |
-
|
427 |
-
|
428 |
-
|
429 |
-
|
430 |
|
431 |
-
|
432 |
|
433 |
|
434 |
def get_non_lora_model(base_model, model_loader, load_half, model_kwargs, reward_type,
|
435 |
gpu_id=0,
|
436 |
use_auth_token=False,
|
437 |
trust_remote_code=True,
|
|
|
438 |
triton_attn=False,
|
439 |
long_sequence=True,
|
440 |
):
|
@@ -448,6 +567,7 @@ def get_non_lora_model(base_model, model_loader, load_half, model_kwargs, reward
|
|
448 |
:param gpu_id:
|
449 |
:param use_auth_token:
|
450 |
:param trust_remote_code:
|
|
|
451 |
:param triton_attn:
|
452 |
:param long_sequence:
|
453 |
:return:
|
@@ -455,7 +575,8 @@ def get_non_lora_model(base_model, model_loader, load_half, model_kwargs, reward
|
|
455 |
with init_empty_weights():
|
456 |
from transformers import AutoConfig
|
457 |
config = AutoConfig.from_pretrained(base_model, use_auth_token=use_auth_token,
|
458 |
-
trust_remote_code=trust_remote_code
|
|
|
459 |
if triton_attn and 'mpt-' in base_model.lower():
|
460 |
config.attn_config['attn_impl'] = 'triton'
|
461 |
if long_sequence:
|
@@ -485,7 +606,6 @@ def get_non_lora_model(base_model, model_loader, load_half, model_kwargs, reward
|
|
485 |
dtype=torch.float16 if load_half else torch.float32,
|
486 |
)
|
487 |
device_map.update(device_map_model)
|
488 |
-
print('device_map: %s' % device_map, flush=True)
|
489 |
else:
|
490 |
device_map = "auto"
|
491 |
|
@@ -504,6 +624,7 @@ def get_non_lora_model(base_model, model_loader, load_half, model_kwargs, reward
|
|
504 |
else:
|
505 |
device_map = {'': 'cpu'}
|
506 |
model_kwargs['load_in_8bit'] = False
|
|
|
507 |
|
508 |
load_in_8bit = model_kwargs.get('load_in_8bit', False)
|
509 |
model_kwargs['device_map'] = device_map
|
@@ -537,6 +658,7 @@ def get_model(
|
|
537 |
resume_download: bool = True,
|
538 |
use_auth_token: Union[str, bool] = False,
|
539 |
trust_remote_code: bool = True,
|
|
|
540 |
compile: bool = True,
|
541 |
**kwargs,
|
542 |
):
|
@@ -556,11 +678,17 @@ def get_model(
|
|
556 |
:param resume_download: resume downloads from HF
|
557 |
:param use_auth_token: assumes user did on CLI `huggingface-cli login` to access private repo
|
558 |
:param trust_remote_code: trust code needed by model
|
|
|
559 |
:param compile: whether to compile torch model
|
560 |
:param kwargs:
|
561 |
:return:
|
562 |
"""
|
563 |
print("Get %s model" % base_model, flush=True)
|
|
|
|
|
|
|
|
|
|
|
564 |
if lora_weights is not None and lora_weights.strip():
|
565 |
print("Get %s lora weights" % lora_weights, flush=True)
|
566 |
device = get_device()
|
@@ -575,7 +703,8 @@ def get_model(
|
|
575 |
|
576 |
from transformers import AutoConfig
|
577 |
config = AutoConfig.from_pretrained(base_model, use_auth_token=use_auth_token,
|
578 |
-
trust_remote_code=trust_remote_code
|
|
|
579 |
llama_type_from_config = 'llama' in str(config).lower()
|
580 |
llama_type_from_name = "llama" in base_model.lower()
|
581 |
llama_type = llama_type_from_config or llama_type_from_name
|
@@ -593,6 +722,7 @@ def get_model(
|
|
593 |
resume_download=resume_download,
|
594 |
use_auth_token=use_auth_token,
|
595 |
trust_remote_code=trust_remote_code,
|
|
|
596 |
)
|
597 |
else:
|
598 |
tokenizer = tokenizer_loader
|
@@ -610,6 +740,7 @@ def get_model(
|
|
610 |
resume_download=resume_download,
|
611 |
use_auth_token=use_auth_token,
|
612 |
trust_remote_code=trust_remote_code,
|
|
|
613 |
)
|
614 |
if 'mbart-' not in base_model.lower() and 'mpt-' not in base_model.lower():
|
615 |
model_kwargs.update(dict(load_in_8bit=load_8bit,
|
@@ -630,6 +761,7 @@ def get_model(
|
|
630 |
gpu_id=gpu_id,
|
631 |
use_auth_token=use_auth_token,
|
632 |
trust_remote_code=trust_remote_code,
|
|
|
633 |
)
|
634 |
else:
|
635 |
if load_half and not load_8bit:
|
@@ -653,6 +785,7 @@ def get_model(
|
|
653 |
resume_download=resume_download,
|
654 |
use_auth_token=use_auth_token,
|
655 |
trust_remote_code=trust_remote_code,
|
|
|
656 |
device_map={"": 0} if device == 'cuda' else {"": 'cpu'}, # seems to be required
|
657 |
)
|
658 |
else:
|
@@ -669,6 +802,7 @@ def get_model(
|
|
669 |
resume_download=resume_download,
|
670 |
use_auth_token=use_auth_token,
|
671 |
trust_remote_code=trust_remote_code,
|
|
|
672 |
device_map="auto",
|
673 |
)
|
674 |
if load_half:
|
@@ -729,11 +863,13 @@ eval_func_param_names = ['instruction',
|
|
729 |
'chat',
|
730 |
'instruction_nochat',
|
731 |
'iinput_nochat',
|
|
|
732 |
]
|
733 |
|
734 |
|
735 |
def evaluate(
|
736 |
model_state,
|
|
|
737 |
# START NOTE: Examples must have same order of parameters
|
738 |
instruction,
|
739 |
iinput,
|
@@ -754,6 +890,7 @@ def evaluate(
|
|
754 |
chat,
|
755 |
instruction_nochat,
|
756 |
iinput_nochat,
|
|
|
757 |
# END NOTE: Examples must have same order of parameters
|
758 |
src_lang=None,
|
759 |
tgt_lang=None,
|
@@ -766,12 +903,34 @@ def evaluate(
|
|
766 |
raise_generate_gpu_exceptions=None,
|
767 |
chat_context=None,
|
768 |
lora_weights=None,
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
769 |
):
|
770 |
# ensure passed these
|
771 |
assert concurrency_count is not None
|
772 |
assert is_low_mem is not None
|
773 |
assert raise_generate_gpu_exceptions is not None
|
774 |
assert chat_context is not None
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
775 |
|
776 |
if debug:
|
777 |
locals_dict = locals().copy()
|
@@ -817,10 +976,58 @@ def evaluate(
|
|
817 |
# get hidden context if have one
|
818 |
context = get_context(chat_context, prompt_type)
|
819 |
|
820 |
-
data_point = dict(context=context, instruction=instruction, input=iinput)
|
821 |
prompter = Prompter(prompt_type, debug=debug, chat=chat, stream_output=stream_output)
|
|
|
822 |
prompt = prompter.generate_prompt(data_point)
|
823 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
824 |
if isinstance(tokenizer, str):
|
825 |
# pipeline
|
826 |
if tokenizer == "summarization":
|
@@ -838,18 +1045,14 @@ def evaluate(
|
|
838 |
# override, ignore user change
|
839 |
num_return_sequences = 1
|
840 |
stopping_criteria = get_stopping(prompt_type, tokenizer, device)
|
841 |
-
|
842 |
-
|
843 |
-
# RuntimeError: expected scalar type Half but found Float
|
844 |
-
# with - 256
|
845 |
-
max_length_tokenize = 768 - 256 if is_low_mem else 2048 - 256
|
846 |
-
cutoff_len = max_length_tokenize * 4 # if reaches limit, then can't generate new tokens
|
847 |
-
output_smallest = 30 * 4
|
848 |
-
prompt = prompt[-cutoff_len - output_smallest:]
|
849 |
inputs = tokenizer(prompt,
|
850 |
return_tensors="pt",
|
851 |
truncation=True,
|
852 |
max_length=max_length_tokenize)
|
|
|
|
|
853 |
if debug and len(inputs["input_ids"]) > 0:
|
854 |
print('input_ids length', len(inputs["input_ids"][0]), flush=True)
|
855 |
input_ids = inputs["input_ids"].to(device)
|
@@ -891,7 +1094,7 @@ def evaluate(
|
|
891 |
**decoder_kwargs
|
892 |
)
|
893 |
decoder_raw_kwargs = dict(skip_special_tokens=False,
|
894 |
-
|
895 |
|
896 |
decoder_raw = functools.partial(tokenizer.decode,
|
897 |
**decoder_raw_kwargs
|
@@ -904,7 +1107,7 @@ def evaluate(
|
|
904 |
# else hit bitsandbytes lack of thread safety:
|
905 |
# https://github.com/h2oai/h2ogpt/issues/104
|
906 |
# but only makes sense if concurrency_count == 1
|
907 |
-
context_class = NullContext
|
908 |
print('Pre-Generate: %s' % str(datetime.now()), flush=True)
|
909 |
decoded_output = None
|
910 |
with context_class("generate.lock"):
|
@@ -923,7 +1126,9 @@ def evaluate(
|
|
923 |
inputs_decoded = prompt = inputs_decoded_raw
|
924 |
decoder = decoder_raw
|
925 |
decoder_kwargs = decoder_raw_kwargs
|
926 |
-
elif inputs_decoded_raw.replace("<unk> ", "").replace("<unk>", "").replace('\n', ' ').replace(' ',
|
|
|
|
|
927 |
inputs_decoded = prompt = inputs_decoded_raw
|
928 |
decoder = decoder_raw
|
929 |
decoder_kwargs = decoder_raw_kwargs
|
@@ -931,13 +1136,15 @@ def evaluate(
|
|
931 |
print("WARNING: Special characters in prompt", flush=True)
|
932 |
if stream_output:
|
933 |
skip_prompt = False
|
934 |
-
streamer = H2OTextIteratorStreamer(tokenizer, skip_prompt=skip_prompt, block=False,
|
|
|
935 |
gen_kwargs.update(dict(streamer=streamer))
|
936 |
-
|
937 |
-
|
938 |
-
raise_generate_gpu_exceptions,
|
|
|
939 |
bucket = queue.Queue()
|
940 |
-
thread = EThread(target=target,
|
941 |
thread.start()
|
942 |
outputs = ""
|
943 |
try:
|
@@ -969,7 +1176,30 @@ def evaluate(
|
|
969 |
decoded_output = prompt + outputs[0]
|
970 |
if save_dir and decoded_output:
|
971 |
save_generate_output(output=decoded_output, base_model=base_model, save_dir=save_dir)
|
972 |
-
print('Post-Generate: %s decoded_output: %s' % (
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
973 |
|
974 |
|
975 |
class H2OTextIteratorStreamer(TextIteratorStreamer):
|
@@ -977,6 +1207,7 @@ class H2OTextIteratorStreamer(TextIteratorStreamer):
|
|
977 |
normally, timeout required for now to handle exceptions, else get()
|
978 |
but with H2O version of TextIteratorStreamer, loop over block to handle
|
979 |
"""
|
|
|
980 |
def __init__(self, tokenizer, skip_prompt: bool = False, timeout: typing.Optional[float] = None,
|
981 |
block=True, **decode_kwargs):
|
982 |
super().__init__(tokenizer, skip_prompt, **decode_kwargs)
|
@@ -1003,7 +1234,7 @@ class H2OTextIteratorStreamer(TextIteratorStreamer):
|
|
1003 |
print("hit stop", flush=True)
|
1004 |
# could raise or break, maybe best to raise and make parent see if any exception in thread
|
1005 |
raise StopIteration()
|
1006 |
-
#break
|
1007 |
value = self.text_queue.get(block=self.block, timeout=self.timeout)
|
1008 |
break
|
1009 |
except queue.Empty:
|
@@ -1014,15 +1245,16 @@ class H2OTextIteratorStreamer(TextIteratorStreamer):
|
|
1014 |
return value
|
1015 |
|
1016 |
|
1017 |
-
def generate_with_exceptions(func, prompt, inputs_decoded, raise_generate_gpu_exceptions, **kwargs):
|
1018 |
try:
|
1019 |
-
func(**kwargs)
|
1020 |
except torch.cuda.OutOfMemoryError as e:
|
1021 |
print("GPU OOM 2: prompt: %s inputs_decoded: %s exception: %s" % (prompt, inputs_decoded, str(e)),
|
1022 |
flush=True)
|
1023 |
-
if
|
1024 |
-
kwargs['input_ids']
|
1025 |
-
|
|
|
1026 |
traceback.print_exc()
|
1027 |
clear_torch_cache()
|
1028 |
return
|
@@ -1214,7 +1446,7 @@ y = np.random.randint(0, 1, 100)
|
|
1214 |
|
1215 |
# move to correct position
|
1216 |
for example in examples:
|
1217 |
-
example += [chat, '', '']
|
1218 |
# adjust examples if non-chat mode
|
1219 |
if not chat:
|
1220 |
example[eval_func_param_names.index('instruction_nochat')] = example[
|
@@ -1223,16 +1455,18 @@ y = np.random.randint(0, 1, 100)
|
|
1223 |
|
1224 |
example[eval_func_param_names.index('iinput_nochat')] = example[eval_func_param_names.index('iinput')]
|
1225 |
example[eval_func_param_names.index('iinput')] = ''
|
|
|
|
|
1226 |
|
1227 |
return placeholder_instruction, placeholder_input, \
|
1228 |
-
|
1229 |
-
|
1230 |
-
|
1231 |
-
|
1232 |
-
|
1233 |
-
|
1234 |
-
|
1235 |
-
|
1236 |
|
1237 |
|
1238 |
def languages_covered():
|
@@ -1252,12 +1486,6 @@ def get_context(chat_context, prompt_type):
|
|
1252 |
return context0
|
1253 |
|
1254 |
|
1255 |
-
def test_test_prompt(prompt_type='instruct', data_point=0):
|
1256 |
-
example_data_point = example_data_points[data_point]
|
1257 |
-
example_data_point.pop('output', None)
|
1258 |
-
return generate_prompt(example_data_point, prompt_type, False, False)
|
1259 |
-
|
1260 |
-
|
1261 |
def score_qa(smodel, stokenizer, max_length_tokenize, question, answer, cutoff_len):
|
1262 |
question = question[-cutoff_len:]
|
1263 |
answer = answer[-cutoff_len:]
|
@@ -1321,39 +1549,3 @@ if __name__ == "__main__":
|
|
1321 |
python generate.py --base_model=h2oai/h2ogpt-oig-oasst1-512-6.9b
|
1322 |
"""
|
1323 |
fire.Fire(main)
|
1324 |
-
|
1325 |
-
|
1326 |
-
import pytest
|
1327 |
-
|
1328 |
-
@pytest.mark.parametrize(
|
1329 |
-
"base_model",
|
1330 |
-
[
|
1331 |
-
"h2oai/h2ogpt-oig-oasst1-512-6.9b",
|
1332 |
-
"h2oai/h2ogpt-oig-oasst1-512-12b",
|
1333 |
-
"h2oai/h2ogpt-oig-oasst1-512-20b",
|
1334 |
-
"h2oai/h2ogpt-oasst1-512-12b",
|
1335 |
-
"h2oai/h2ogpt-oasst1-512-20b",
|
1336 |
-
"h2oai/h2ogpt-gm-oasst1-en-1024-20b",
|
1337 |
-
"databricks/dolly-v2-12b",
|
1338 |
-
"h2oai/h2ogpt-gm-oasst1-en-2048-open-llama-7b-preview-300bt-v2",
|
1339 |
-
"ehartford/WizardLM-7B-Uncensored",
|
1340 |
-
"ehartford/WizardLM-13B-Uncensored",
|
1341 |
-
"AlekseyKorshuk/vicuna-7b",
|
1342 |
-
"TheBloke/stable-vicuna-13B-HF",
|
1343 |
-
"decapoda-research/llama-7b-hf",
|
1344 |
-
"decapoda-research/llama-13b-hf",
|
1345 |
-
"decapoda-research/llama-30b-hf",
|
1346 |
-
"junelee/wizard-vicuna-13b",
|
1347 |
-
]
|
1348 |
-
)
|
1349 |
-
def test_score_eval(base_model):
|
1350 |
-
main(
|
1351 |
-
base_model=base_model,
|
1352 |
-
chat=False,
|
1353 |
-
stream_output=False,
|
1354 |
-
gradio=False,
|
1355 |
-
eval_sharegpt_prompts_only=500,
|
1356 |
-
eval_sharegpt_as_output=False,
|
1357 |
-
num_beams=2,
|
1358 |
-
infer_devices=False,
|
1359 |
-
)
|
|
|
1 |
+
import ast
|
2 |
import functools
|
3 |
+
import glob
|
4 |
+
import inspect
|
5 |
import queue
|
6 |
+
import shutil
|
7 |
import sys
|
8 |
import os
|
9 |
import time
|
|
|
13 |
import filelock
|
14 |
import psutil
|
15 |
|
16 |
+
from loaders import get_loaders
|
17 |
+
from utils import set_seed, clear_torch_cache, save_generate_output, NullContext, wrapped_partial, EThread, get_githash, \
|
18 |
+
import_matplotlib, get_device, makedirs
|
19 |
+
|
20 |
+
import_matplotlib()
|
21 |
+
from matplotlib import pyplot as plt
|
22 |
|
23 |
SEED = 1236
|
24 |
set_seed(SEED)
|
|
|
34 |
from transformers import GenerationConfig, AutoModel, TextIteratorStreamer
|
35 |
from accelerate import init_empty_weights, infer_auto_device_map
|
36 |
|
37 |
+
from prompter import Prompter, inv_prompt_type_to_model_lower
|
|
|
|
|
38 |
from stopping import get_stopping
|
39 |
|
40 |
eval_extra_columns = ['prompt', 'response', 'score']
|
41 |
|
42 |
+
langchain_modes = ['Disabled', 'ChatLLM', 'LLM', 'All', 'wiki', 'wiki_full', 'UserData', 'MyData', 'github h2oGPT',
|
43 |
+
'DriverlessAI docs']
|
44 |
+
|
45 |
+
scratch_base_dir = '/tmp/'
|
46 |
+
|
47 |
|
48 |
def main(
|
49 |
load_8bit: bool = False,
|
|
|
75 |
resume_download: bool = True,
|
76 |
use_auth_token: Union[str, bool] = False,
|
77 |
trust_remote_code: Union[str, bool] = True,
|
78 |
+
offload_folder: str = "offline_folder",
|
79 |
|
80 |
src_lang: str = "English",
|
81 |
tgt_lang: str = "Russian",
|
|
|
83 |
gradio: bool = True,
|
84 |
gradio_avoid_processing_markdown: bool = False,
|
85 |
chat: bool = True,
|
|
|
86 |
chat_context: bool = False,
|
87 |
stream_output: bool = True,
|
88 |
show_examples: bool = None,
|
|
|
110 |
eval_sharegpt_prompts_only: int = 0,
|
111 |
eval_sharegpt_prompts_only_seed: int = 1234,
|
112 |
eval_sharegpt_as_output: bool = False,
|
113 |
+
|
114 |
+
langchain_mode: str = 'Disabled',
|
115 |
+
visible_langchain_modes: list = ['UserData', 'MyData'],
|
116 |
+
user_path: str = None,
|
117 |
+
load_db_if_exists: bool = True,
|
118 |
+
keep_sources_in_context: bool = False,
|
119 |
+
db_type: str = 'chroma',
|
120 |
+
use_openai_embedding: bool = False,
|
121 |
+
use_openai_model: bool = False,
|
122 |
+
hf_embedding_model: str = "sentence-transformers/all-MiniLM-L6-v2",
|
123 |
+
allow_upload_to_user_data: bool = True,
|
124 |
+
allow_upload_to_my_data: bool = True,
|
125 |
+
enable_url_upload: bool = True,
|
126 |
+
enable_text_upload: bool = True,
|
127 |
+
enable_sources_list: bool = True,
|
128 |
+
chunk: bool = True,
|
129 |
+
chunk_size: int = 512,
|
130 |
+
k: int = 4,
|
131 |
+
n_jobs: int = -1,
|
132 |
+
enable_captions: bool = True,
|
133 |
+
captions_model: str = "Salesforce/blip-image-captioning-base",
|
134 |
+
pre_load_caption_model: bool = False,
|
135 |
+
caption_gpu: bool = True,
|
136 |
+
enable_ocr: bool = False,
|
137 |
):
|
138 |
"""
|
139 |
|
|
|
163 |
:param resume_download: whether to resume downloads from HF for models
|
164 |
:param use_auth_token: whether to use HF auth token (requires CLI did huggingface-cli login before)
|
165 |
:param trust_remote_code: whether to use trust any code needed for HF model
|
166 |
+
:param offload_folder: path for spilling model onto disk
|
167 |
:param src_lang: source languages to include if doing translation (None = all)
|
168 |
:param tgt_lang: target languages to include if doing translation (None = all)
|
169 |
:param gradio: whether to enable gradio, or to enable benchmark mode
|
170 |
:param gradio_avoid_processing_markdown:
|
171 |
:param chat: whether to enable chat mode with chat history
|
|
|
172 |
:param chat_context: whether to use extra helpful context if human_bot
|
173 |
:param stream_output: whether to stream output from generate
|
174 |
:param show_examples: whether to show clickable examples in gradio
|
|
|
193 |
:param eval_sharegpt_prompts_only: for no gradio benchmark, if using ShareGPT prompts for eval
|
194 |
:param eval_sharegpt_prompts_only_seed: for no gradio benchmark, if seed for ShareGPT sampling
|
195 |
:param eval_sharegpt_as_output: for no gradio benchmark, whether to test ShareGPT output itself
|
196 |
+
:param langchain_mode: Data source to include. Choose "UserData" to only consume files from make_db.py.
|
197 |
+
WARNING: wiki_full requires extra data processing via read_wiki_full.py and requires really good workstation to generate db, unless already present.
|
198 |
+
:param user_path: user path to glob from to generate db for vector search, for 'UserData' langchain mode
|
199 |
+
:param visible_langchain_modes: dbs to generate at launch to be ready for LLM
|
200 |
+
Can be up to ['wiki', 'wiki_full', 'UserData', 'MyData', 'github h2oGPT', 'DriverlessAI docs']
|
201 |
+
But wiki_full is expensive and requires preparation
|
202 |
+
To allow scratch space only live in session, add 'MyData' to list
|
203 |
+
Default: If only want to consume local files, e.g. prepared by make_db.py, only include ['UserData']
|
204 |
+
FIXME: Avoid 'All' for now, not implemented
|
205 |
+
:param load_db_if_exists: Whether to load chroma db if exists or re-generate db
|
206 |
+
:param keep_sources_in_context: Whether to keep url sources in context, not helpful usually
|
207 |
+
:param db_type: 'faiss' for in-memory or 'chroma' for persisted on disk
|
208 |
+
:param use_openai_embedding: Whether to use OpenAI embeddings for vector db
|
209 |
+
:param use_openai_model: Whether to use OpenAI model for use with vector db
|
210 |
+
:param hf_embedding_model: Which HF embedding model to use for vector db
|
211 |
+
:param allow_upload_to_user_data: Whether to allow file uploads to update shared vector db
|
212 |
+
:param allow_upload_to_my_data: Whether to allow file uploads to update scratch vector db
|
213 |
+
:param enable_url_upload: Whether to allow upload from URL
|
214 |
+
:param enable_text_upload: Whether to allow uplaod of text
|
215 |
+
:param enable_sources_list: Whether to allow list (or download for non-shared db) of list of sources for chosen db
|
216 |
+
:param chunk: Whether to chunk data (True unless know data is already optimally chunked)
|
217 |
+
:param chunk_size: Size of chunks, with typically top-4 passed to LLM, so neesd to be in context length
|
218 |
+
:param k: number of chunks to give LLM
|
219 |
+
:param n_jobs: Number of processors to use when consuming documents (-1 = all, is default)
|
220 |
+
:param enable_captions: Whether to support captions using BLIP for image files as documents, then preloads that model
|
221 |
+
:param captions_model: Which model to use for captions.
|
222 |
+
captions_model: int = "Salesforce/blip-image-captioning-base", # continue capable
|
223 |
+
captions_model: str = "Salesforce/blip2-flan-t5-xl", # question/answer capable, 16GB state
|
224 |
+
captions_model: int = "Salesforce/blip2-flan-t5-xxl", # question/answer capable, 60GB state
|
225 |
+
Note: opt-based blip2 are not permissive license due to opt and Meta license restrictions
|
226 |
+
:param pre_load_caption_model: Whether to preload caption model, or load after forking parallel doc loader
|
227 |
+
parallel loading disabled if preload and have images, to prevent deadlocking on cuda context
|
228 |
+
Recommended if using larger caption model
|
229 |
+
:param caption_gpu: If support caption, then use GPU if exists
|
230 |
+
:param enable_ocr: Whether to support OCR on images
|
231 |
:return:
|
232 |
"""
|
233 |
is_hf = bool(os.getenv("HUGGINGFACE_SPACES"))
|
|
|
241 |
|
242 |
# allow set token directly
|
243 |
use_auth_token = os.environ.get("HUGGINGFACE_API_TOKEN", use_auth_token)
|
244 |
+
allow_upload_to_user_data = bool(os.environ.get("allow_upload_to_user_data", allow_upload_to_user_data))
|
245 |
+
allow_upload_to_my_data = bool(os.environ.get("allow_upload_to_my_data", allow_upload_to_my_data))
|
246 |
+
height = os.environ.get("HEIGHT", height)
|
247 |
+
|
248 |
+
# allow enabling langchain via ENV
|
249 |
+
# FIRST PLACE where LangChain referenced, but no imports related to it
|
250 |
+
langchain_mode = os.environ.get("LANGCHAIN_MODE", langchain_mode)
|
251 |
+
assert langchain_mode in langchain_modes, "Invalid langchain_mode %s" % langchain_mode
|
252 |
+
visible_langchain_modes = ast.literal_eval(os.environ.get("visible_langchain_modes", str(visible_langchain_modes)))
|
253 |
+
if langchain_mode not in visible_langchain_modes and langchain_mode in langchain_modes:
|
254 |
+
visible_langchain_modes += [langchain_mode]
|
255 |
|
256 |
if is_public:
|
257 |
+
allow_upload_to_user_data = False
|
258 |
input_lines = 1 # ensure set, for ease of use
|
259 |
temperature = 0.2 if temperature is None else temperature
|
260 |
top_p = 0.85 if top_p is None else top_p
|
|
|
294 |
torch.backends.cudnn.benchmark = True
|
295 |
torch.backends.cudnn.enabled = False
|
296 |
torch.set_default_dtype(torch.float32)
|
297 |
+
if psutil.virtual_memory().available < 94 * 1024 ** 3:
|
298 |
# 12B uses ~94GB
|
299 |
# 6.9B uses ~47GB
|
300 |
base_model = 'h2oai/h2ogpt-oig-oasst1-512-6.9b' if not base_model else base_model
|
|
|
306 |
stream_output = False
|
307 |
# else prompt removal can mess up output
|
308 |
chat = False
|
309 |
+
# hard-coded defaults
|
310 |
+
first_para = False
|
311 |
+
text_limit = None
|
312 |
+
|
313 |
+
if offload_folder:
|
314 |
+
makedirs(offload_folder)
|
315 |
|
316 |
placeholder_instruction, placeholder_input, \
|
317 |
+
stream_output, show_examples, \
|
318 |
+
prompt_type, temperature, top_p, top_k, num_beams, \
|
319 |
+
max_new_tokens, min_new_tokens, early_stopping, max_time, \
|
320 |
+
repetition_penalty, num_return_sequences, \
|
321 |
+
do_sample, \
|
322 |
+
src_lang, tgt_lang, \
|
323 |
+
examples, \
|
324 |
+
task_info = \
|
325 |
get_generate_params(model_lower, chat,
|
326 |
stream_output, show_examples,
|
327 |
prompt_type, temperature, top_p, top_k, num_beams,
|
|
|
335 |
print(f"Generating model with params:\n{locals_print}", flush=True)
|
336 |
print("Command: %s\nHash: %s" % (str(' '.join(sys.argv)), get_githash()), flush=True)
|
337 |
|
338 |
+
if langchain_mode != "Disabled":
|
339 |
+
# SECOND PLACE where LangChain referenced, but all imports are kept local so not required
|
340 |
+
from gpt_langchain import prep_langchain, get_some_dbs_from_hf
|
341 |
+
if is_hf:
|
342 |
+
get_some_dbs_from_hf()
|
343 |
+
dbs = {}
|
344 |
+
for langchain_mode1 in visible_langchain_modes:
|
345 |
+
if langchain_mode1 in ['MyData']:
|
346 |
+
# don't use what is on disk, remove it instead
|
347 |
+
for gpath1 in glob.glob(os.path.join(scratch_base_dir, 'db_dir_%s*' % langchain_mode1)):
|
348 |
+
if os.path.isdir(gpath1):
|
349 |
+
print("Removing old MyData: %s" % gpath1, flush=True)
|
350 |
+
shutil.rmtree(gpath1)
|
351 |
+
continue
|
352 |
+
if langchain_mode1 in ['All']:
|
353 |
+
# FIXME: All should be avoided until scans over each db, shouldn't be separate db
|
354 |
+
continue
|
355 |
+
persist_directory1 = 'db_dir_%s' % langchain_mode1 # single place, no special names for each case
|
356 |
+
db = prep_langchain(persist_directory1, load_db_if_exists, db_type, use_openai_embedding,
|
357 |
+
langchain_mode1, user_path,
|
358 |
+
hf_embedding_model,
|
359 |
+
kwargs_make_db=locals())
|
360 |
+
dbs[langchain_mode1] = db
|
361 |
+
# remove None db's so can just rely upon k in dbs for if hav db
|
362 |
+
dbs = {k: v for k, v in dbs.items() if v is not None}
|
363 |
+
else:
|
364 |
+
dbs = {}
|
365 |
+
# import control
|
366 |
+
if os.environ.get("TEST_LANGCHAIN_IMPORT"):
|
367 |
+
assert 'gpt_langchain' not in sys.modules, "Dev bug, import of langchain when should not have"
|
368 |
+
assert 'langchain' not in sys.modules, "Dev bug, import of langchain when should not have"
|
369 |
+
|
370 |
if not gradio:
|
371 |
if eval_sharegpt_prompts_only > 0:
|
372 |
# override default examples with shareGPT ones for human-level eval purposes only
|
|
|
430 |
if not eval_sharegpt_as_output:
|
431 |
model, tokenizer, device = get_model(**locals())
|
432 |
model_state = [model, tokenizer, device, base_model]
|
433 |
+
kwargs_evaluate = {k: v for k, v in locals().items() if k in inputs_kwargs_list}
|
434 |
+
my_db_state = [None]
|
435 |
+
fun = partial(evaluate, model_state, my_db_state, **kwargs_evaluate)
|
|
|
|
|
436 |
else:
|
437 |
assert eval_sharegpt_prompts_only > 0
|
438 |
|
|
|
444 |
t0 = time.time()
|
445 |
score_dump = []
|
446 |
|
|
|
|
|
447 |
for exi, ex in enumerate(examples):
|
448 |
instruction = ex[eval_func_param_names.index('instruction_nochat')]
|
449 |
iinput = ex[eval_func_param_names.index('iinput_nochat')]
|
|
|
480 |
try:
|
481 |
score = torch.sigmoid(smodel(**inputs).logits[0].float()).cpu().detach().numpy()[0]
|
482 |
except torch.cuda.OutOfMemoryError as e:
|
483 |
+
print("GPU OOM 1: question: %s answer: %s exception: %s" % (prompt, res, str(e)),
|
484 |
+
flush=True)
|
485 |
traceback.print_exc()
|
486 |
score = 0.0
|
487 |
clear_torch_cache()
|
|
|
537 |
smodel, stokenizer, sdevice = get_score_model(**all_kwargs)
|
538 |
score_model_state0 = [smodel, stokenizer, sdevice, score_model]
|
539 |
|
540 |
+
if enable_captions:
|
541 |
+
if pre_load_caption_model:
|
542 |
+
from image_captions import H2OImageCaptionLoader
|
543 |
+
caption_loader = H2OImageCaptionLoader(caption_gpu=caption_gpu).load_model()
|
544 |
+
else:
|
545 |
+
caption_loader = 'gpu' if caption_gpu else 'cpu'
|
546 |
+
else:
|
547 |
+
caption_loader = False
|
548 |
|
549 |
+
go_gradio(**locals())
|
550 |
|
551 |
|
552 |
def get_non_lora_model(base_model, model_loader, load_half, model_kwargs, reward_type,
|
553 |
gpu_id=0,
|
554 |
use_auth_token=False,
|
555 |
trust_remote_code=True,
|
556 |
+
offload_folder=None,
|
557 |
triton_attn=False,
|
558 |
long_sequence=True,
|
559 |
):
|
|
|
567 |
:param gpu_id:
|
568 |
:param use_auth_token:
|
569 |
:param trust_remote_code:
|
570 |
+
:param offload_folder:
|
571 |
:param triton_attn:
|
572 |
:param long_sequence:
|
573 |
:return:
|
|
|
575 |
with init_empty_weights():
|
576 |
from transformers import AutoConfig
|
577 |
config = AutoConfig.from_pretrained(base_model, use_auth_token=use_auth_token,
|
578 |
+
trust_remote_code=trust_remote_code,
|
579 |
+
offload_folder=offload_folder)
|
580 |
if triton_attn and 'mpt-' in base_model.lower():
|
581 |
config.attn_config['attn_impl'] = 'triton'
|
582 |
if long_sequence:
|
|
|
606 |
dtype=torch.float16 if load_half else torch.float32,
|
607 |
)
|
608 |
device_map.update(device_map_model)
|
|
|
609 |
else:
|
610 |
device_map = "auto"
|
611 |
|
|
|
624 |
else:
|
625 |
device_map = {'': 'cpu'}
|
626 |
model_kwargs['load_in_8bit'] = False
|
627 |
+
print('device_map: %s' % device_map, flush=True)
|
628 |
|
629 |
load_in_8bit = model_kwargs.get('load_in_8bit', False)
|
630 |
model_kwargs['device_map'] = device_map
|
|
|
658 |
resume_download: bool = True,
|
659 |
use_auth_token: Union[str, bool] = False,
|
660 |
trust_remote_code: bool = True,
|
661 |
+
offload_folder: str = None,
|
662 |
compile: bool = True,
|
663 |
**kwargs,
|
664 |
):
|
|
|
678 |
:param resume_download: resume downloads from HF
|
679 |
:param use_auth_token: assumes user did on CLI `huggingface-cli login` to access private repo
|
680 |
:param trust_remote_code: trust code needed by model
|
681 |
+
:param offload_folder: offload folder
|
682 |
:param compile: whether to compile torch model
|
683 |
:param kwargs:
|
684 |
:return:
|
685 |
"""
|
686 |
print("Get %s model" % base_model, flush=True)
|
687 |
+
if base_model in ['llama', 'gptj']:
|
688 |
+
from gpt4all_llm import get_model_tokenizer_gpt4all
|
689 |
+
model, tokenizer, device = get_model_tokenizer_gpt4all(base_model)
|
690 |
+
return model, tokenizer, device
|
691 |
+
|
692 |
if lora_weights is not None and lora_weights.strip():
|
693 |
print("Get %s lora weights" % lora_weights, flush=True)
|
694 |
device = get_device()
|
|
|
703 |
|
704 |
from transformers import AutoConfig
|
705 |
config = AutoConfig.from_pretrained(base_model, use_auth_token=use_auth_token,
|
706 |
+
trust_remote_code=trust_remote_code,
|
707 |
+
offload_folder=offload_folder)
|
708 |
llama_type_from_config = 'llama' in str(config).lower()
|
709 |
llama_type_from_name = "llama" in base_model.lower()
|
710 |
llama_type = llama_type_from_config or llama_type_from_name
|
|
|
722 |
resume_download=resume_download,
|
723 |
use_auth_token=use_auth_token,
|
724 |
trust_remote_code=trust_remote_code,
|
725 |
+
offload_folder=offload_folder,
|
726 |
)
|
727 |
else:
|
728 |
tokenizer = tokenizer_loader
|
|
|
740 |
resume_download=resume_download,
|
741 |
use_auth_token=use_auth_token,
|
742 |
trust_remote_code=trust_remote_code,
|
743 |
+
offload_folder=offload_folder,
|
744 |
)
|
745 |
if 'mbart-' not in base_model.lower() and 'mpt-' not in base_model.lower():
|
746 |
model_kwargs.update(dict(load_in_8bit=load_8bit,
|
|
|
761 |
gpu_id=gpu_id,
|
762 |
use_auth_token=use_auth_token,
|
763 |
trust_remote_code=trust_remote_code,
|
764 |
+
offload_folder=offload_folder,
|
765 |
)
|
766 |
else:
|
767 |
if load_half and not load_8bit:
|
|
|
785 |
resume_download=resume_download,
|
786 |
use_auth_token=use_auth_token,
|
787 |
trust_remote_code=trust_remote_code,
|
788 |
+
offload_folder=offload_folder,
|
789 |
device_map={"": 0} if device == 'cuda' else {"": 'cpu'}, # seems to be required
|
790 |
)
|
791 |
else:
|
|
|
802 |
resume_download=resume_download,
|
803 |
use_auth_token=use_auth_token,
|
804 |
trust_remote_code=trust_remote_code,
|
805 |
+
offload_folder=offload_folder,
|
806 |
device_map="auto",
|
807 |
)
|
808 |
if load_half:
|
|
|
863 |
'chat',
|
864 |
'instruction_nochat',
|
865 |
'iinput_nochat',
|
866 |
+
'langchain_mode',
|
867 |
]
|
868 |
|
869 |
|
870 |
def evaluate(
|
871 |
model_state,
|
872 |
+
my_db_state,
|
873 |
# START NOTE: Examples must have same order of parameters
|
874 |
instruction,
|
875 |
iinput,
|
|
|
890 |
chat,
|
891 |
instruction_nochat,
|
892 |
iinput_nochat,
|
893 |
+
langchain_mode,
|
894 |
# END NOTE: Examples must have same order of parameters
|
895 |
src_lang=None,
|
896 |
tgt_lang=None,
|
|
|
903 |
raise_generate_gpu_exceptions=None,
|
904 |
chat_context=None,
|
905 |
lora_weights=None,
|
906 |
+
load_db_if_exists=True,
|
907 |
+
dbs=None,
|
908 |
+
user_path=None,
|
909 |
+
use_openai_embedding=None,
|
910 |
+
use_openai_model=None,
|
911 |
+
hf_embedding_model=None,
|
912 |
+
chunk=None,
|
913 |
+
chunk_size=None,
|
914 |
+
db_type=None,
|
915 |
+
k=None,
|
916 |
+
n_jobs=None,
|
917 |
+
first_para=None,
|
918 |
+
text_limit=None,
|
919 |
):
|
920 |
# ensure passed these
|
921 |
assert concurrency_count is not None
|
922 |
assert is_low_mem is not None
|
923 |
assert raise_generate_gpu_exceptions is not None
|
924 |
assert chat_context is not None
|
925 |
+
assert use_openai_embedding is not None
|
926 |
+
assert use_openai_model is not None
|
927 |
+
assert hf_embedding_model is not None
|
928 |
+
assert chunk is not None
|
929 |
+
assert chunk_size is not None
|
930 |
+
assert db_type is not None
|
931 |
+
assert k is not None
|
932 |
+
assert n_jobs is not None
|
933 |
+
assert first_para is not None
|
934 |
|
935 |
if debug:
|
936 |
locals_dict = locals().copy()
|
|
|
976 |
# get hidden context if have one
|
977 |
context = get_context(chat_context, prompt_type)
|
978 |
|
|
|
979 |
prompter = Prompter(prompt_type, debug=debug, chat=chat, stream_output=stream_output)
|
980 |
+
data_point = dict(context=context, instruction=instruction, input=iinput)
|
981 |
prompt = prompter.generate_prompt(data_point)
|
982 |
|
983 |
+
# THIRD PLACE where LangChain referenced, but imports only occur if enabled and have db to use
|
984 |
+
assert langchain_mode in langchain_modes, "Invalid langchain_mode %s" % langchain_mode
|
985 |
+
if langchain_mode in ['MyData'] and my_db_state is not None and len(my_db_state) > 0 and my_db_state[0] is not None:
|
986 |
+
db1 = my_db_state[0]
|
987 |
+
elif dbs is not None and langchain_mode in dbs:
|
988 |
+
db1 = dbs[langchain_mode]
|
989 |
+
else:
|
990 |
+
db1 = None
|
991 |
+
if langchain_mode not in [False, 'Disabled', 'ChatLLM', 'LLM'] and db1 is not None or base_model in ['llama', 'gptj']:
|
992 |
+
query = instruction if not iinput else "%s\n%s" % (instruction, iinput)
|
993 |
+
outr = ""
|
994 |
+
# use smaller cut_distanct for wiki_full since so many matches could be obtained, and often irrelevant unless close
|
995 |
+
from gpt_langchain import run_qa_db
|
996 |
+
for r in run_qa_db(query=query,
|
997 |
+
model_name=base_model, model=model, tokenizer=tokenizer,
|
998 |
+
stream_output=stream_output,
|
999 |
+
prompter=prompter,
|
1000 |
+
load_db_if_exists=load_db_if_exists,
|
1001 |
+
db=db1,
|
1002 |
+
user_path=user_path,
|
1003 |
+
max_new_tokens=max_new_tokens,
|
1004 |
+
cut_distanct=1.1 if langchain_mode in ['wiki_full'] else 1.64, # FIXME, too arbitrary
|
1005 |
+
use_openai_embedding=use_openai_embedding,
|
1006 |
+
use_openai_model=use_openai_model,
|
1007 |
+
hf_embedding_model=hf_embedding_model,
|
1008 |
+
first_para=first_para,
|
1009 |
+
text_limit=text_limit,
|
1010 |
+
chunk=chunk,
|
1011 |
+
chunk_size=chunk_size,
|
1012 |
+
langchain_mode=langchain_mode,
|
1013 |
+
db_type=db_type,
|
1014 |
+
k=k,
|
1015 |
+
temperature=temperature,
|
1016 |
+
repetition_penalty=repetition_penalty,
|
1017 |
+
top_k=top_k,
|
1018 |
+
top_p=top_p,
|
1019 |
+
prompt_type=prompt_type,
|
1020 |
+
n_jobs=n_jobs,
|
1021 |
+
):
|
1022 |
+
outr = r # doesn't accumulate, new answer every yield, so only save that full answer
|
1023 |
+
yield r
|
1024 |
+
if save_dir:
|
1025 |
+
save_generate_output(output=outr, base_model=base_model, save_dir=save_dir)
|
1026 |
+
print('Post-Generate Langchain: %s decoded_output: %s' % (str(datetime.now()), len(outr) if outr else -1),
|
1027 |
+
flush=True)
|
1028 |
+
if outr:
|
1029 |
+
return
|
1030 |
+
|
1031 |
if isinstance(tokenizer, str):
|
1032 |
# pipeline
|
1033 |
if tokenizer == "summarization":
|
|
|
1045 |
# override, ignore user change
|
1046 |
num_return_sequences = 1
|
1047 |
stopping_criteria = get_stopping(prompt_type, tokenizer, device)
|
1048 |
+
_, _, max_length_tokenize, max_prompt_length = get_cutoffs(is_low_mem)
|
1049 |
+
prompt = prompt[-max_prompt_length:]
|
|
|
|
|
|
|
|
|
|
|
|
|
1050 |
inputs = tokenizer(prompt,
|
1051 |
return_tensors="pt",
|
1052 |
truncation=True,
|
1053 |
max_length=max_length_tokenize)
|
1054 |
+
if inputs['input_ids'].shape[1] >= max_length_tokenize - 1:
|
1055 |
+
print("Cutting off input: %s %s" % (inputs['input_ids'].shape[1], max_length_tokenize), flush=True)
|
1056 |
if debug and len(inputs["input_ids"]) > 0:
|
1057 |
print('input_ids length', len(inputs["input_ids"][0]), flush=True)
|
1058 |
input_ids = inputs["input_ids"].to(device)
|
|
|
1094 |
**decoder_kwargs
|
1095 |
)
|
1096 |
decoder_raw_kwargs = dict(skip_special_tokens=False,
|
1097 |
+
clean_up_tokenization_spaces=True)
|
1098 |
|
1099 |
decoder_raw = functools.partial(tokenizer.decode,
|
1100 |
**decoder_raw_kwargs
|
|
|
1107 |
# else hit bitsandbytes lack of thread safety:
|
1108 |
# https://github.com/h2oai/h2ogpt/issues/104
|
1109 |
# but only makes sense if concurrency_count == 1
|
1110 |
+
context_class = NullContext # if concurrency_count > 1 else filelock.FileLock
|
1111 |
print('Pre-Generate: %s' % str(datetime.now()), flush=True)
|
1112 |
decoded_output = None
|
1113 |
with context_class("generate.lock"):
|
|
|
1126 |
inputs_decoded = prompt = inputs_decoded_raw
|
1127 |
decoder = decoder_raw
|
1128 |
decoder_kwargs = decoder_raw_kwargs
|
1129 |
+
elif inputs_decoded_raw.replace("<unk> ", "").replace("<unk>", "").replace('\n', ' ').replace(' ',
|
1130 |
+
'') == prompt.replace(
|
1131 |
+
'\n', ' ').replace(' ', ''):
|
1132 |
inputs_decoded = prompt = inputs_decoded_raw
|
1133 |
decoder = decoder_raw
|
1134 |
decoder_kwargs = decoder_raw_kwargs
|
|
|
1136 |
print("WARNING: Special characters in prompt", flush=True)
|
1137 |
if stream_output:
|
1138 |
skip_prompt = False
|
1139 |
+
streamer = H2OTextIteratorStreamer(tokenizer, skip_prompt=skip_prompt, block=False,
|
1140 |
+
**decoder_kwargs)
|
1141 |
gen_kwargs.update(dict(streamer=streamer))
|
1142 |
+
target = wrapped_partial(generate_with_exceptions, model.generate,
|
1143 |
+
prompt=prompt, inputs_decoded=inputs_decoded,
|
1144 |
+
raise_generate_gpu_exceptions=raise_generate_gpu_exceptions,
|
1145 |
+
**gen_kwargs)
|
1146 |
bucket = queue.Queue()
|
1147 |
+
thread = EThread(target=target, streamer=streamer, bucket=bucket)
|
1148 |
thread.start()
|
1149 |
outputs = ""
|
1150 |
try:
|
|
|
1176 |
decoded_output = prompt + outputs[0]
|
1177 |
if save_dir and decoded_output:
|
1178 |
save_generate_output(output=decoded_output, base_model=base_model, save_dir=save_dir)
|
1179 |
+
print('Post-Generate: %s decoded_output: %s' % (
|
1180 |
+
str(datetime.now()), len(decoded_output) if decoded_output else -1), flush=True)
|
1181 |
+
|
1182 |
+
|
1183 |
+
inputs_list_names = list(inspect.signature(evaluate).parameters)
|
1184 |
+
state_names = ['model_state', 'my_db_state']
|
1185 |
+
inputs_kwargs_list = [x for x in inputs_list_names if x not in eval_func_param_names + state_names]
|
1186 |
+
|
1187 |
+
|
1188 |
+
def get_cutoffs(is_low_mem, for_context=False):
|
1189 |
+
# help to avoid errors like:
|
1190 |
+
# RuntimeError: The size of tensor a (2048) must match the size of tensor b (2049) at non-singleton dimension 3
|
1191 |
+
# RuntimeError: expected scalar type Half but found Float
|
1192 |
+
# with - 256
|
1193 |
+
max_length_tokenize = 768 - 256 if is_low_mem else 2048 - 256
|
1194 |
+
cutoff_len = max_length_tokenize * 4 # if reaches limit, then can't generate new tokens
|
1195 |
+
output_smallest = 30 * 4
|
1196 |
+
max_prompt_length = cutoff_len - output_smallest
|
1197 |
+
|
1198 |
+
if for_context:
|
1199 |
+
# then lower even more to avoid later chop, since just estimate tokens in context bot
|
1200 |
+
max_prompt_length = max(64, int(max_prompt_length * 0.8))
|
1201 |
+
|
1202 |
+
return cutoff_len, output_smallest, max_length_tokenize, max_prompt_length
|
1203 |
|
1204 |
|
1205 |
class H2OTextIteratorStreamer(TextIteratorStreamer):
|
|
|
1207 |
normally, timeout required for now to handle exceptions, else get()
|
1208 |
but with H2O version of TextIteratorStreamer, loop over block to handle
|
1209 |
"""
|
1210 |
+
|
1211 |
def __init__(self, tokenizer, skip_prompt: bool = False, timeout: typing.Optional[float] = None,
|
1212 |
block=True, **decode_kwargs):
|
1213 |
super().__init__(tokenizer, skip_prompt, **decode_kwargs)
|
|
|
1234 |
print("hit stop", flush=True)
|
1235 |
# could raise or break, maybe best to raise and make parent see if any exception in thread
|
1236 |
raise StopIteration()
|
1237 |
+
# break
|
1238 |
value = self.text_queue.get(block=self.block, timeout=self.timeout)
|
1239 |
break
|
1240 |
except queue.Empty:
|
|
|
1245 |
return value
|
1246 |
|
1247 |
|
1248 |
+
def generate_with_exceptions(func, *args, prompt='', inputs_decoded='', raise_generate_gpu_exceptions=True, **kwargs):
|
1249 |
try:
|
1250 |
+
func(*args, **kwargs)
|
1251 |
except torch.cuda.OutOfMemoryError as e:
|
1252 |
print("GPU OOM 2: prompt: %s inputs_decoded: %s exception: %s" % (prompt, inputs_decoded, str(e)),
|
1253 |
flush=True)
|
1254 |
+
if 'input_ids' in kwargs:
|
1255 |
+
if kwargs['input_ids'] is not None:
|
1256 |
+
kwargs['input_ids'].cpu()
|
1257 |
+
kwargs['input_ids'] = None
|
1258 |
traceback.print_exc()
|
1259 |
clear_torch_cache()
|
1260 |
return
|
|
|
1446 |
|
1447 |
# move to correct position
|
1448 |
for example in examples:
|
1449 |
+
example += [chat, '', '', 'Disabled']
|
1450 |
# adjust examples if non-chat mode
|
1451 |
if not chat:
|
1452 |
example[eval_func_param_names.index('instruction_nochat')] = example[
|
|
|
1455 |
|
1456 |
example[eval_func_param_names.index('iinput_nochat')] = example[eval_func_param_names.index('iinput')]
|
1457 |
example[eval_func_param_names.index('iinput')] = ''
|
1458 |
+
assert len(example) == len(eval_func_param_names), "Wrong example: %s %s" % (
|
1459 |
+
len(example), len(eval_func_param_names))
|
1460 |
|
1461 |
return placeholder_instruction, placeholder_input, \
|
1462 |
+
stream_output, show_examples, \
|
1463 |
+
prompt_type, temperature, top_p, top_k, num_beams, \
|
1464 |
+
max_new_tokens, min_new_tokens, early_stopping, max_time, \
|
1465 |
+
repetition_penalty, num_return_sequences, \
|
1466 |
+
do_sample, \
|
1467 |
+
src_lang, tgt_lang, \
|
1468 |
+
examples, \
|
1469 |
+
task_info
|
1470 |
|
1471 |
|
1472 |
def languages_covered():
|
|
|
1486 |
return context0
|
1487 |
|
1488 |
|
|
|
|
|
|
|
|
|
|
|
|
|
1489 |
def score_qa(smodel, stokenizer, max_length_tokenize, question, answer, cutoff_len):
|
1490 |
question = question[-cutoff_len:]
|
1491 |
answer = answer[-cutoff_len:]
|
|
|
1549 |
python generate.py --base_model=h2oai/h2ogpt-oig-oasst1-512-6.9b
|
1550 |
"""
|
1551 |
fire.Fire(main)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
gpt4all_llm.py
ADDED
@@ -0,0 +1,119 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import inspect
|
2 |
+
import os
|
3 |
+
from typing import Dict, Any, Optional, List
|
4 |
+
from langchain.callbacks.manager import CallbackManagerForLLMRun
|
5 |
+
from pydantic import root_validator
|
6 |
+
from langchain.llms import gpt4all
|
7 |
+
from dotenv import dotenv_values
|
8 |
+
|
9 |
+
|
10 |
+
class FakeTokenizer:
|
11 |
+
|
12 |
+
def encode(self, x, *args, **kwargs):
|
13 |
+
return dict(input_ids=[x])
|
14 |
+
|
15 |
+
def decode(self, x, *args, **kwargs):
|
16 |
+
return x
|
17 |
+
|
18 |
+
def __call__(self, x, *args, **kwargs):
|
19 |
+
return self.encode(x, *args, **kwargs)
|
20 |
+
|
21 |
+
|
22 |
+
def get_model_tokenizer_gpt4all(base_model, **kwargs):
|
23 |
+
# defaults (some of these are generation parameters, so need to be passed in at generation time)
|
24 |
+
model_kwargs = dict(n_ctx=kwargs.get('max_new_tokens', 256),
|
25 |
+
n_threads=os.cpu_count() // 2,
|
26 |
+
temp=kwargs.get('temperature', 0.2),
|
27 |
+
top_p=kwargs.get('top_p', 0.75),
|
28 |
+
top_k=kwargs.get('top_k', 40))
|
29 |
+
env_gpt4all_file = ".env_gpt4all"
|
30 |
+
model_kwargs.update(dotenv_values(env_gpt4all_file))
|
31 |
+
|
32 |
+
if base_model == "llama":
|
33 |
+
if 'model_path_llama' not in model_kwargs:
|
34 |
+
raise ValueError("No model_path_llama in %s" % env_gpt4all_file)
|
35 |
+
model_path = model_kwargs.pop('model_path_llama')
|
36 |
+
from gpt4all import GPT4All as GPT4AllModel
|
37 |
+
elif base_model == "gptj":
|
38 |
+
if 'model_path_gptj' not in model_kwargs:
|
39 |
+
raise ValueError("No model_path_gptj in %s" % env_gpt4all_file)
|
40 |
+
model_path = model_kwargs.pop('model_path_gptj')
|
41 |
+
from gpt4all import GPT4All as GPT4AllModel
|
42 |
+
else:
|
43 |
+
raise ValueError("No such base_model %s" % base_model)
|
44 |
+
func_names = list(inspect.signature(GPT4AllModel).parameters)
|
45 |
+
model_kwargs = {k: v for k, v in model_kwargs.items() if k in func_names}
|
46 |
+
model = GPT4AllModel(model_path, **model_kwargs)
|
47 |
+
return model, FakeTokenizer(), 'cpu'
|
48 |
+
|
49 |
+
|
50 |
+
def get_llm_gpt4all(model_name, model=None,
|
51 |
+
max_new_tokens=256,
|
52 |
+
temperature=0.1,
|
53 |
+
repetition_penalty=1.0,
|
54 |
+
top_k=40,
|
55 |
+
top_p=0.7):
|
56 |
+
env_gpt4all_file = ".env_gpt4all"
|
57 |
+
model_kwargs = dotenv_values(env_gpt4all_file)
|
58 |
+
from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
|
59 |
+
callbacks = [StreamingStdOutCallbackHandler()]
|
60 |
+
n_ctx = model_kwargs.pop('n_ctx', 1024)
|
61 |
+
default_params = {'context_erase': 0.5, 'n_batch': 1, 'n_ctx': n_ctx, 'n_predict': max_new_tokens,
|
62 |
+
'repeat_last_n': 64 if repetition_penalty != 1.0 else 0, 'repeat_penalty': repetition_penalty,
|
63 |
+
'temp': temperature, 'top_k': top_k, 'top_p': top_p}
|
64 |
+
if model_name == 'llama':
|
65 |
+
from langchain.llms import LlamaCpp
|
66 |
+
model_path = model_kwargs.pop('model_path_llama') if model is None else model
|
67 |
+
llm = LlamaCpp(model_path=model_path, n_ctx=n_ctx, callbacks=callbacks, verbose=False)
|
68 |
+
else:
|
69 |
+
model_path = model_kwargs.pop('model_path_gptj') if model is None else model
|
70 |
+
llm = H2OGPT4All(model=model_path, backend='gptj', callbacks=callbacks,
|
71 |
+
verbose=False, **default_params,
|
72 |
+
)
|
73 |
+
return llm
|
74 |
+
|
75 |
+
|
76 |
+
class H2OGPT4All(gpt4all.GPT4All):
|
77 |
+
model: Any
|
78 |
+
"""Path to the pre-trained GPT4All model file."""
|
79 |
+
|
80 |
+
@root_validator()
|
81 |
+
def validate_environment(cls, values: Dict) -> Dict:
|
82 |
+
"""Validate that the python package exists in the environment."""
|
83 |
+
try:
|
84 |
+
if isinstance(values["model"], str):
|
85 |
+
from gpt4all import GPT4All as GPT4AllModel
|
86 |
+
|
87 |
+
full_path = values["model"]
|
88 |
+
model_path, delimiter, model_name = full_path.rpartition("/")
|
89 |
+
model_path += delimiter
|
90 |
+
|
91 |
+
values["client"] = GPT4AllModel(
|
92 |
+
model_name=model_name,
|
93 |
+
model_path=model_path or None,
|
94 |
+
model_type=values["backend"],
|
95 |
+
allow_download=False,
|
96 |
+
)
|
97 |
+
else:
|
98 |
+
values["client"] = values["model"]
|
99 |
+
values["backend"] = values["client"].model.model_type
|
100 |
+
|
101 |
+
except ImportError:
|
102 |
+
raise ValueError(
|
103 |
+
"Could not import gpt4all python package. "
|
104 |
+
"Please install it with `pip install gpt4all`."
|
105 |
+
)
|
106 |
+
return values
|
107 |
+
|
108 |
+
def _call(
|
109 |
+
self,
|
110 |
+
prompt: str,
|
111 |
+
stop: Optional[List[str]] = None,
|
112 |
+
run_manager: Optional[CallbackManagerForLLMRun] = None,
|
113 |
+
) -> str:
|
114 |
+
# Roughly 4 chars per token if natural language
|
115 |
+
prompt = prompt[-self.n_ctx * 4:]
|
116 |
+
verbose = False
|
117 |
+
if verbose:
|
118 |
+
print("_call prompt: %s" % prompt, flush=True)
|
119 |
+
return super()._call(prompt, stop=stop, run_manager=run_manager)
|
gpt_langchain.py
ADDED
@@ -0,0 +1,1076 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import glob
|
2 |
+
import inspect
|
3 |
+
import os
|
4 |
+
import pathlib
|
5 |
+
import pickle
|
6 |
+
import shutil
|
7 |
+
import subprocess
|
8 |
+
import sys
|
9 |
+
import tempfile
|
10 |
+
import traceback
|
11 |
+
import uuid
|
12 |
+
import zipfile
|
13 |
+
from collections import defaultdict
|
14 |
+
from datetime import datetime
|
15 |
+
from functools import reduce
|
16 |
+
from operator import concat
|
17 |
+
|
18 |
+
from joblib import Parallel, delayed
|
19 |
+
|
20 |
+
from utils import wrapped_partial, EThread, import_matplotlib, sanitize_filename, makedirs, get_url, flatten_list, \
|
21 |
+
get_device
|
22 |
+
|
23 |
+
import_matplotlib()
|
24 |
+
|
25 |
+
import numpy as np
|
26 |
+
import pandas as pd
|
27 |
+
import requests
|
28 |
+
from langchain.chains.qa_with_sources import load_qa_with_sources_chain
|
29 |
+
# , GCSDirectoryLoader, GCSFileLoader
|
30 |
+
# , OutlookMessageLoader # GPL3
|
31 |
+
# ImageCaptionLoader, # use our own wrapper
|
32 |
+
# ReadTheDocsLoader, # no special file, some path, so have to give as special option
|
33 |
+
from langchain.document_loaders import PyPDFLoader, TextLoader, CSVLoader, PythonLoader, TomlLoader, \
|
34 |
+
UnstructuredURLLoader, UnstructuredHTMLLoader, UnstructuredWordDocumentLoader, UnstructuredMarkdownLoader, \
|
35 |
+
EverNoteLoader, UnstructuredEmailLoader, UnstructuredODTLoader, UnstructuredPowerPointLoader, \
|
36 |
+
UnstructuredEPubLoader, UnstructuredImageLoader, UnstructuredRTFLoader, ArxivLoader
|
37 |
+
from langchain.text_splitter import RecursiveCharacterTextSplitter
|
38 |
+
from langchain.vectorstores import FAISS
|
39 |
+
from langchain.chains.question_answering import load_qa_chain
|
40 |
+
from langchain.docstore.document import Document
|
41 |
+
from langchain import PromptTemplate
|
42 |
+
from langchain.vectorstores import Chroma
|
43 |
+
|
44 |
+
|
45 |
+
def get_db(sources, use_openai_embedding=False, db_type='faiss', persist_directory="db_dir", langchain_mode='notset',
|
46 |
+
hf_embedding_model="sentence-transformers/all-MiniLM-L6-v2"):
|
47 |
+
if not sources:
|
48 |
+
return None
|
49 |
+
# get embedding model
|
50 |
+
embedding = get_embedding(use_openai_embedding, hf_embedding_model=hf_embedding_model)
|
51 |
+
|
52 |
+
# Create vector database
|
53 |
+
if db_type == 'faiss':
|
54 |
+
db = FAISS.from_documents(sources, embedding)
|
55 |
+
elif db_type == 'chroma':
|
56 |
+
collection_name = langchain_mode.replace(' ', '_')
|
57 |
+
os.makedirs(persist_directory, exist_ok=True)
|
58 |
+
db = Chroma.from_documents(documents=sources,
|
59 |
+
embedding=embedding,
|
60 |
+
persist_directory=persist_directory,
|
61 |
+
collection_name=collection_name,
|
62 |
+
anonymized_telemetry=False)
|
63 |
+
db.persist()
|
64 |
+
# FIXME: below just proves can load persistent dir, regenerates its embedding files, so a bit wasteful
|
65 |
+
if False:
|
66 |
+
db = Chroma(embedding_function=embedding,
|
67 |
+
persist_directory=persist_directory,
|
68 |
+
collection_name=collection_name)
|
69 |
+
else:
|
70 |
+
raise RuntimeError("No such db_type=%s" % db_type)
|
71 |
+
|
72 |
+
return db
|
73 |
+
|
74 |
+
|
75 |
+
def add_to_db(db, sources, db_type='faiss', avoid_dup=True):
|
76 |
+
if not sources:
|
77 |
+
return db
|
78 |
+
if db_type == 'faiss':
|
79 |
+
db.add_documents(sources)
|
80 |
+
elif db_type == 'chroma':
|
81 |
+
if avoid_dup:
|
82 |
+
collection = db.get()
|
83 |
+
metadata_sources = set([x['source'] for x in collection['metadatas']])
|
84 |
+
sources = [x for x in sources if x.metadata['source'] not in metadata_sources]
|
85 |
+
if len(sources) == 0:
|
86 |
+
return db
|
87 |
+
db.add_documents(documents=sources)
|
88 |
+
db.persist()
|
89 |
+
else:
|
90 |
+
raise RuntimeError("No such db_type=%s" % db_type)
|
91 |
+
|
92 |
+
return db
|
93 |
+
|
94 |
+
|
95 |
+
def get_embedding(use_openai_embedding, hf_embedding_model="sentence-transformers/all-MiniLM-L6-v2"):
|
96 |
+
# Get embedding model
|
97 |
+
if use_openai_embedding:
|
98 |
+
assert os.getenv("OPENAI_API_KEY") is not None, "Set ENV OPENAI_API_KEY"
|
99 |
+
from langchain.embeddings import OpenAIEmbeddings
|
100 |
+
embedding = OpenAIEmbeddings()
|
101 |
+
else:
|
102 |
+
# to ensure can fork without deadlock
|
103 |
+
from langchain.embeddings import HuggingFaceEmbeddings
|
104 |
+
|
105 |
+
device, torch_dtype, context_class = get_device_dtype()
|
106 |
+
model_kwargs = dict(device=device)
|
107 |
+
embedding = HuggingFaceEmbeddings(model_name=hf_embedding_model, model_kwargs=model_kwargs)
|
108 |
+
return embedding
|
109 |
+
|
110 |
+
|
111 |
+
def get_answer_from_sources(chain, sources, question):
|
112 |
+
return chain(
|
113 |
+
{
|
114 |
+
"input_documents": sources,
|
115 |
+
"question": question,
|
116 |
+
},
|
117 |
+
return_only_outputs=True,
|
118 |
+
)["output_text"]
|
119 |
+
|
120 |
+
|
121 |
+
def get_llm(use_openai_model=False, model_name=None, model=None,
|
122 |
+
tokenizer=None, stream_output=False,
|
123 |
+
max_new_tokens=256,
|
124 |
+
temperature=0.1,
|
125 |
+
repetition_penalty=1.0,
|
126 |
+
top_k=40,
|
127 |
+
top_p=0.7,
|
128 |
+
prompt_type=None,
|
129 |
+
):
|
130 |
+
if use_openai_model:
|
131 |
+
from langchain.llms import OpenAI
|
132 |
+
llm = OpenAI(temperature=0)
|
133 |
+
model_name = 'openai'
|
134 |
+
streamer = None
|
135 |
+
elif model_name in ['gptj', 'llama']:
|
136 |
+
from gpt4all_llm import get_llm_gpt4all
|
137 |
+
llm = get_llm_gpt4all(model_name, model=model, max_new_tokens=max_new_tokens,
|
138 |
+
temperature=temperature,
|
139 |
+
repetition_penalty=repetition_penalty,
|
140 |
+
top_k=top_k,
|
141 |
+
top_p=top_p,
|
142 |
+
)
|
143 |
+
streamer = None
|
144 |
+
prompt_type = 'plain'
|
145 |
+
else:
|
146 |
+
from transformers import AutoTokenizer, AutoModelForCausalLM
|
147 |
+
|
148 |
+
if model is None:
|
149 |
+
# only used if didn't pass model in
|
150 |
+
assert model_name is None
|
151 |
+
assert tokenizer is None
|
152 |
+
model_name = 'h2oai/h2ogpt-oasst1-512-12b'
|
153 |
+
# model_name = 'h2oai/h2ogpt-oig-oasst1-512-6.9b'
|
154 |
+
# model_name = 'h2oai/h2ogpt-oasst1-512-20b'
|
155 |
+
tokenizer = AutoTokenizer.from_pretrained(model_name)
|
156 |
+
device, torch_dtype, context_class = get_device_dtype()
|
157 |
+
|
158 |
+
with context_class(device):
|
159 |
+
load_8bit = True
|
160 |
+
# FIXME: for now not to spread across hetero GPUs
|
161 |
+
# device_map={"": 0} if load_8bit and device == 'cuda' else "auto"
|
162 |
+
device_map = {"": 0} if device == 'cuda' else "auto"
|
163 |
+
model = AutoModelForCausalLM.from_pretrained(model_name,
|
164 |
+
device_map=device_map,
|
165 |
+
torch_dtype=torch_dtype,
|
166 |
+
load_in_8bit=load_8bit)
|
167 |
+
|
168 |
+
gen_kwargs = dict(max_new_tokens=max_new_tokens, return_full_text=True, early_stopping=False)
|
169 |
+
if stream_output:
|
170 |
+
skip_prompt = False
|
171 |
+
from generate import H2OTextIteratorStreamer
|
172 |
+
decoder_kwargs = {}
|
173 |
+
streamer = H2OTextIteratorStreamer(tokenizer, skip_prompt=skip_prompt, block=False, **decoder_kwargs)
|
174 |
+
gen_kwargs.update(dict(streamer=streamer))
|
175 |
+
else:
|
176 |
+
streamer = None
|
177 |
+
|
178 |
+
if 'h2ogpt' in model_name or prompt_type == 'human_bot':
|
179 |
+
from h2oai_pipeline import H2OTextGenerationPipeline
|
180 |
+
pipe = H2OTextGenerationPipeline(model=model, tokenizer=tokenizer, **gen_kwargs)
|
181 |
+
# pipe.task = "text-generation"
|
182 |
+
# below makes it listen only to our prompt removal, not built in prompt removal that is less general and not specific for our model
|
183 |
+
pipe.task = "text2text-generation"
|
184 |
+
prompt_type = 'human_bot'
|
185 |
+
else:
|
186 |
+
# only for non-instruct tuned cases when ok with just normal next token prediction
|
187 |
+
from transformers import pipeline
|
188 |
+
pipe = pipeline("text-generation", model=model, tokenizer=tokenizer, **gen_kwargs)
|
189 |
+
|
190 |
+
from langchain.llms import HuggingFacePipeline
|
191 |
+
llm = HuggingFacePipeline(pipeline=pipe)
|
192 |
+
return llm, model_name, streamer, prompt_type
|
193 |
+
|
194 |
+
|
195 |
+
def get_device_dtype():
|
196 |
+
# torch.device("cuda") leads to cuda:x cuda:y mismatches for multi-GPU consistently
|
197 |
+
import torch
|
198 |
+
n_gpus = torch.cuda.device_count() if torch.cuda.is_available else 0
|
199 |
+
device = 'cpu' if n_gpus == 0 else 'cuda'
|
200 |
+
# from utils import NullContext
|
201 |
+
# context_class = NullContext if n_gpus > 1 or n_gpus == 0 else context_class
|
202 |
+
context_class = torch.device
|
203 |
+
torch_dtype = torch.float16 if device == 'cuda' else torch.float32
|
204 |
+
return device, torch_dtype, context_class
|
205 |
+
|
206 |
+
|
207 |
+
def get_wiki_data(title, first_paragraph_only, text_limit=None, take_head=True):
|
208 |
+
"""
|
209 |
+
Get wikipedia data from online
|
210 |
+
:param title:
|
211 |
+
:param first_paragraph_only:
|
212 |
+
:param text_limit:
|
213 |
+
:param take_head:
|
214 |
+
:return:
|
215 |
+
"""
|
216 |
+
filename = 'wiki_%s_%s_%s_%s.data' % (first_paragraph_only, title, text_limit, take_head)
|
217 |
+
url = f"https://en.wikipedia.org/w/api.php?format=json&action=query&prop=extracts&explaintext=1&titles={title}"
|
218 |
+
if first_paragraph_only:
|
219 |
+
url += "&exintro=1"
|
220 |
+
import json
|
221 |
+
if not os.path.isfile(filename):
|
222 |
+
data = requests.get(url).json()
|
223 |
+
json.dump(data, open(filename, 'wt'))
|
224 |
+
else:
|
225 |
+
data = json.load(open(filename, "rt"))
|
226 |
+
page_content = list(data["query"]["pages"].values())[0]["extract"]
|
227 |
+
if take_head is not None and text_limit is not None:
|
228 |
+
page_content = page_content[:text_limit] if take_head else page_content[:-text_limit]
|
229 |
+
title_url = str(title).replace(' ', '_')
|
230 |
+
return Document(
|
231 |
+
page_content=page_content,
|
232 |
+
metadata={"source": f"https://en.wikipedia.org/wiki/{title_url}"},
|
233 |
+
)
|
234 |
+
|
235 |
+
|
236 |
+
def get_wiki_sources(first_para=True, text_limit=None):
|
237 |
+
"""
|
238 |
+
Get specific named sources from wikipedia
|
239 |
+
:param first_para:
|
240 |
+
:param text_limit:
|
241 |
+
:return:
|
242 |
+
"""
|
243 |
+
default_wiki_sources = ['Unix', 'Microsoft_Windows', 'Linux']
|
244 |
+
wiki_sources = list(os.getenv('WIKI_SOURCES', default_wiki_sources))
|
245 |
+
return [get_wiki_data(x, first_para, text_limit=text_limit) for x in wiki_sources]
|
246 |
+
|
247 |
+
|
248 |
+
def get_github_docs(repo_owner, repo_name):
|
249 |
+
"""
|
250 |
+
Access github from specific repo
|
251 |
+
:param repo_owner:
|
252 |
+
:param repo_name:
|
253 |
+
:return:
|
254 |
+
"""
|
255 |
+
with tempfile.TemporaryDirectory() as d:
|
256 |
+
subprocess.check_call(
|
257 |
+
f"git clone --depth 1 https://github.com/{repo_owner}/{repo_name}.git .",
|
258 |
+
cwd=d,
|
259 |
+
shell=True,
|
260 |
+
)
|
261 |
+
git_sha = (
|
262 |
+
subprocess.check_output("git rev-parse HEAD", shell=True, cwd=d)
|
263 |
+
.decode("utf-8")
|
264 |
+
.strip()
|
265 |
+
)
|
266 |
+
repo_path = pathlib.Path(d)
|
267 |
+
markdown_files = list(repo_path.glob("*/*.md")) + list(
|
268 |
+
repo_path.glob("*/*.mdx")
|
269 |
+
)
|
270 |
+
for markdown_file in markdown_files:
|
271 |
+
with open(markdown_file, "r") as f:
|
272 |
+
relative_path = markdown_file.relative_to(repo_path)
|
273 |
+
github_url = f"https://github.com/{repo_owner}/{repo_name}/blob/{git_sha}/{relative_path}"
|
274 |
+
yield Document(page_content=f.read(), metadata={"source": github_url})
|
275 |
+
|
276 |
+
|
277 |
+
def get_dai_pickle(dest="."):
|
278 |
+
from huggingface_hub import hf_hub_download
|
279 |
+
# True for case when locally already logged in with correct token, so don't have to set key
|
280 |
+
token = os.getenv('HUGGINGFACE_API_TOKEN', True)
|
281 |
+
path_to_zip_file = hf_hub_download('h2oai/dai_docs', 'dai_docs.pickle', token=token, repo_type='dataset')
|
282 |
+
shutil.copy(path_to_zip_file, dest)
|
283 |
+
|
284 |
+
|
285 |
+
def get_dai_docs(from_hf=False, get_pickle=True):
|
286 |
+
"""
|
287 |
+
Consume DAI documentation, or consume from public pickle
|
288 |
+
:param from_hf: get DAI docs from HF, then generate pickle for later use by LangChain
|
289 |
+
:param get_pickle: Avoid raw DAI docs, just get pickle directly from HF
|
290 |
+
:return:
|
291 |
+
"""
|
292 |
+
import pickle
|
293 |
+
|
294 |
+
if get_pickle:
|
295 |
+
get_dai_pickle()
|
296 |
+
|
297 |
+
dai_store = 'dai_docs.pickle'
|
298 |
+
dst = "working_dir_docs"
|
299 |
+
if not os.path.isfile(dai_store):
|
300 |
+
from create_data import setup_dai_docs
|
301 |
+
dst = setup_dai_docs(dst=dst, from_hf=from_hf)
|
302 |
+
|
303 |
+
import glob
|
304 |
+
files = list(glob.glob(os.path.join(dst, '*rst'), recursive=True))
|
305 |
+
|
306 |
+
basedir = os.path.abspath(os.getcwd())
|
307 |
+
from create_data import rst_to_outputs
|
308 |
+
new_outputs = rst_to_outputs(files)
|
309 |
+
os.chdir(basedir)
|
310 |
+
|
311 |
+
pickle.dump(new_outputs, open(dai_store, 'wb'))
|
312 |
+
else:
|
313 |
+
new_outputs = pickle.load(open(dai_store, 'rb'))
|
314 |
+
|
315 |
+
sources = []
|
316 |
+
for line, file in new_outputs:
|
317 |
+
# gradio requires any linked file to be with app.py
|
318 |
+
sym_src = os.path.abspath(os.path.join(dst, file))
|
319 |
+
sym_dst = os.path.abspath(os.path.join(os.getcwd(), file))
|
320 |
+
if os.path.lexists(sym_dst):
|
321 |
+
os.remove(sym_dst)
|
322 |
+
os.symlink(sym_src, sym_dst)
|
323 |
+
itm = Document(page_content=line, metadata={"source": file})
|
324 |
+
# NOTE: yield has issues when going into db, loses metadata
|
325 |
+
# yield itm
|
326 |
+
sources.append(itm)
|
327 |
+
return sources
|
328 |
+
|
329 |
+
|
330 |
+
import distutils.spawn
|
331 |
+
|
332 |
+
have_tesseract = distutils.spawn.find_executable("tesseract")
|
333 |
+
have_libreoffice = distutils.spawn.find_executable("libreoffice")
|
334 |
+
|
335 |
+
import pkg_resources
|
336 |
+
|
337 |
+
try:
|
338 |
+
assert pkg_resources.get_distribution('arxiv') is not None
|
339 |
+
assert pkg_resources.get_distribution('pymupdf') is not None
|
340 |
+
have_arxiv = True
|
341 |
+
except (pkg_resources.DistributionNotFound, AssertionError):
|
342 |
+
have_arxiv = False
|
343 |
+
|
344 |
+
image_types = ["png", "jpg", "jpeg"]
|
345 |
+
non_image_types = ["pdf", "txt", "csv", "toml", "py", "rst", "rtf",
|
346 |
+
"md", "html",
|
347 |
+
"enex", "eml", "epub", "odt", "pptx", "ppt",
|
348 |
+
"zip", "urls",
|
349 |
+
]
|
350 |
+
# "msg", GPL3
|
351 |
+
|
352 |
+
if have_libreoffice:
|
353 |
+
non_image_types.extend(["docx", "doc"])
|
354 |
+
|
355 |
+
file_types = non_image_types + image_types
|
356 |
+
|
357 |
+
|
358 |
+
def add_meta(docs1, file):
|
359 |
+
file_extension = pathlib.Path(file).suffix
|
360 |
+
if not isinstance(docs1, list):
|
361 |
+
docs1 = [docs1]
|
362 |
+
[x.metadata.update(dict(input_type=file_extension, date=str(datetime.now))) for x in docs1]
|
363 |
+
|
364 |
+
|
365 |
+
def file_to_doc(file, base_path=None, verbose=False, fail_any_exception=False, chunk=True, chunk_size=512,
|
366 |
+
is_url=False, is_txt=False,
|
367 |
+
enable_captions=True,
|
368 |
+
captions_model=None,
|
369 |
+
enable_ocr=False, caption_loader=None,
|
370 |
+
headsize=50):
|
371 |
+
if file is None:
|
372 |
+
if fail_any_exception:
|
373 |
+
raise RuntimeError("Unexpected None file")
|
374 |
+
else:
|
375 |
+
return []
|
376 |
+
doc1 = [] # in case no support, or disabled support
|
377 |
+
if base_path is None and not is_txt and not is_url:
|
378 |
+
# then assume want to persist but don't care which path used
|
379 |
+
# can't be in base_path
|
380 |
+
dir_name = os.path.dirname(file)
|
381 |
+
base_name = os.path.basename(file)
|
382 |
+
# if from gradio, will have its own temp uuid too, but that's ok
|
383 |
+
base_name = sanitize_filename(base_name) + "_" + str(uuid.uuid4())[:10]
|
384 |
+
base_path = os.path.join(dir_name, base_name)
|
385 |
+
if is_url:
|
386 |
+
if file.lower().startswith('arxiv:'):
|
387 |
+
query = file.lower().split('arxiv:')
|
388 |
+
if len(query) == 2 and have_arxiv:
|
389 |
+
query = query[1]
|
390 |
+
docs1 = ArxivLoader(query=query, load_max_docs=20, load_all_available_meta=True).load()
|
391 |
+
# ensure string, sometimes None
|
392 |
+
[[x.metadata.update({k: str(v)}) for k, v in x.metadata.items()] for x in docs1]
|
393 |
+
query_url = f"https://arxiv.org/abs/{query}"
|
394 |
+
[x.metadata.update(
|
395 |
+
dict(source=x.metadata.get('entry_id', query_url), query=query_url,
|
396 |
+
input_type='arxiv', head=x.metadata.get('Title', ''), date=str(datetime.now))) for x in
|
397 |
+
docs1]
|
398 |
+
else:
|
399 |
+
docs1 = []
|
400 |
+
else:
|
401 |
+
docs1 = UnstructuredURLLoader(urls=[file]).load()
|
402 |
+
[x.metadata.update(dict(input_type='url', date=str(datetime.now))) for x in docs1]
|
403 |
+
doc1 = chunk_sources(docs1, chunk_size=chunk_size)
|
404 |
+
elif is_txt:
|
405 |
+
base_path = "user_paste"
|
406 |
+
source_file = os.path.join(base_path, "_%s" % str(uuid.uuid4())[:10])
|
407 |
+
makedirs(os.path.dirname(source_file), exist_ok=True)
|
408 |
+
with open(source_file, "wt") as f:
|
409 |
+
f.write(file)
|
410 |
+
metadata = dict(source=source_file, date=str(datetime.now()), input_type='pasted txt')
|
411 |
+
doc1 = Document(page_content=file, metadata=metadata)
|
412 |
+
elif file.endswith('.html') or file.endswith('.mhtml'):
|
413 |
+
docs1 = UnstructuredHTMLLoader(file_path=file).load()
|
414 |
+
add_meta(docs1, file)
|
415 |
+
doc1 = chunk_sources(docs1, chunk_size=chunk_size)
|
416 |
+
elif (file.endswith('.docx') or file.endswith('.doc')) and have_libreoffice:
|
417 |
+
docs1 = UnstructuredWordDocumentLoader(file_path=file).load()
|
418 |
+
add_meta(docs1, file)
|
419 |
+
doc1 = chunk_sources(docs1, chunk_size=chunk_size)
|
420 |
+
elif file.endswith('.odt'):
|
421 |
+
docs1 = UnstructuredODTLoader(file_path=file).load()
|
422 |
+
add_meta(docs1, file)
|
423 |
+
doc1 = chunk_sources(docs1, chunk_size=chunk_size)
|
424 |
+
elif file.endswith('pptx') or file.endswith('ppt'):
|
425 |
+
docs1 = UnstructuredPowerPointLoader(file_path=file).load()
|
426 |
+
add_meta(docs1, file)
|
427 |
+
doc1 = chunk_sources(docs1, chunk_size=chunk_size)
|
428 |
+
elif file.endswith('.txt'):
|
429 |
+
# use UnstructuredFileLoader ?
|
430 |
+
doc1 = TextLoader(file, encoding="utf8", autodetect_encoding=True).load()
|
431 |
+
add_meta(doc1, file)
|
432 |
+
elif file.endswith('.rtf'):
|
433 |
+
docs1 = UnstructuredRTFLoader(file).load()
|
434 |
+
add_meta(docs1, file)
|
435 |
+
doc1 = chunk_sources(docs1, chunk_size=chunk_size)
|
436 |
+
elif file.endswith('.md'):
|
437 |
+
docs1 = UnstructuredMarkdownLoader(file).load()
|
438 |
+
add_meta(docs1, file)
|
439 |
+
doc1 = chunk_sources(docs1, chunk_size=chunk_size)
|
440 |
+
elif file.endswith('.enex'):
|
441 |
+
doc1 = EverNoteLoader(file).load()
|
442 |
+
add_meta(doc1, file)
|
443 |
+
elif file.endswith('.epub'):
|
444 |
+
docs1 = UnstructuredEPubLoader(file).load()
|
445 |
+
add_meta(docs1, file)
|
446 |
+
doc1 = chunk_sources(docs1, chunk_size=chunk_size)
|
447 |
+
elif file.endswith('.jpeg') or file.endswith('.jpg') or file.endswith('.png'):
|
448 |
+
docs1 = []
|
449 |
+
if have_tesseract and enable_ocr:
|
450 |
+
# OCR, somewhat works, but not great
|
451 |
+
docs1.extend(UnstructuredImageLoader(file).load())
|
452 |
+
add_meta(docs1, file)
|
453 |
+
if enable_captions:
|
454 |
+
# BLIP
|
455 |
+
if caption_loader is not None and not isinstance(caption_loader, (str, bool)):
|
456 |
+
# assumes didn't fork into this process with joblib, else can deadlock
|
457 |
+
caption_loader.set_image_paths([file])
|
458 |
+
docs1c = caption_loader.load()
|
459 |
+
add_meta(docs1c, file)
|
460 |
+
[x.metadata.update(dict(head=x.page_content[:headsize].strip())) for x in docs1c]
|
461 |
+
docs1.extend(docs1c)
|
462 |
+
else:
|
463 |
+
from image_captions import H2OImageCaptionLoader
|
464 |
+
caption_loader = H2OImageCaptionLoader(caption_gpu=caption_loader == 'gpu',
|
465 |
+
blip_model=captions_model,
|
466 |
+
blip_processor=captions_model)
|
467 |
+
caption_loader.set_image_paths([file])
|
468 |
+
docs1c = caption_loader.load()
|
469 |
+
add_meta(docs1c, file)
|
470 |
+
[x.metadata.update(dict(head=x.page_content[:headsize].strip())) for x in docs1c]
|
471 |
+
docs1.extend(docs1c)
|
472 |
+
for doci in docs1:
|
473 |
+
doci.metadata['source'] = doci.metadata['image_path']
|
474 |
+
if docs1:
|
475 |
+
doc1 = chunk_sources(docs1, chunk_size=chunk_size)
|
476 |
+
elif file.endswith('.msg'):
|
477 |
+
raise RuntimeError("Not supported, GPL3 license")
|
478 |
+
# docs1 = OutlookMessageLoader(file).load()
|
479 |
+
# docs1[0].metadata['source'] = file
|
480 |
+
elif file.endswith('.eml'):
|
481 |
+
try:
|
482 |
+
docs1 = UnstructuredEmailLoader(file).load()
|
483 |
+
add_meta(docs1, file)
|
484 |
+
doc1 = chunk_sources(docs1, chunk_size=chunk_size)
|
485 |
+
except ValueError as e:
|
486 |
+
if 'text/html content not found in email' in str(e):
|
487 |
+
# e.g. plain/text dict key exists, but not
|
488 |
+
# doc1 = TextLoader(file, encoding="utf8").load()
|
489 |
+
docs1 = UnstructuredEmailLoader(file, content_source="text/plain").load()
|
490 |
+
add_meta(docs1, file)
|
491 |
+
doc1 = chunk_sources(docs1, chunk_size=chunk_size)
|
492 |
+
else:
|
493 |
+
raise
|
494 |
+
# elif file.endswith('.gcsdir'):
|
495 |
+
# doc1 = GCSDirectoryLoader(project_name, bucket, prefix).load()
|
496 |
+
# elif file.endswith('.gcsfile'):
|
497 |
+
# doc1 = GCSFileLoader(project_name, bucket, blob).load()
|
498 |
+
elif file.endswith('.rst'):
|
499 |
+
with open(file, "r") as f:
|
500 |
+
doc1 = Document(page_content=f.read(), metadata={"source": file})
|
501 |
+
add_meta(doc1, file)
|
502 |
+
elif file.endswith('.pdf'):
|
503 |
+
# Some PDFs return nothing or junk from PDFMinerLoader
|
504 |
+
# e.g. Beyond fine-tuning_ Classifying high resolution mammograms using function-preserving transformations _ Elsevier Enhanced Reader.pdf
|
505 |
+
doc1 = PyPDFLoader(file).load_and_split()
|
506 |
+
add_meta(doc1, file)
|
507 |
+
elif file.endswith('.csv'):
|
508 |
+
doc1 = CSVLoader(file).load()
|
509 |
+
add_meta(doc1, file)
|
510 |
+
elif file.endswith('.py'):
|
511 |
+
doc1 = PythonLoader(file).load()
|
512 |
+
add_meta(doc1, file)
|
513 |
+
elif file.endswith('.toml'):
|
514 |
+
doc1 = TomlLoader(file).load()
|
515 |
+
add_meta(doc1, file)
|
516 |
+
elif file.endswith('.urls'):
|
517 |
+
with open(file, "r") as f:
|
518 |
+
docs1 = UnstructuredURLLoader(urls=f.readlines()).load()
|
519 |
+
add_meta(docs1, file)
|
520 |
+
doc1 = chunk_sources(docs1, chunk_size=chunk_size)
|
521 |
+
elif file.endswith('.zip'):
|
522 |
+
with zipfile.ZipFile(file, 'r') as zip_ref:
|
523 |
+
# don't put into temporary path, since want to keep references to docs inside zip
|
524 |
+
# so just extract in path where
|
525 |
+
zip_ref.extractall(base_path)
|
526 |
+
# recurse
|
527 |
+
doc1 = path_to_docs(base_path, verbose=verbose, fail_any_exception=fail_any_exception)
|
528 |
+
else:
|
529 |
+
raise RuntimeError("No file handler for %s" % os.path.basename(file))
|
530 |
+
|
531 |
+
# allow doc1 to be list or not. If not list, did not chunk yet, so chunk now
|
532 |
+
if not isinstance(doc1, list):
|
533 |
+
if chunk:
|
534 |
+
docs = chunk_sources([doc1], chunk_size=chunk_size)
|
535 |
+
else:
|
536 |
+
docs = [doc1]
|
537 |
+
else:
|
538 |
+
docs = doc1
|
539 |
+
|
540 |
+
assert isinstance(docs, list)
|
541 |
+
return docs
|
542 |
+
|
543 |
+
|
544 |
+
def path_to_doc1(file, verbose=False, fail_any_exception=False, return_file=True, chunk=True, chunk_size=512,
|
545 |
+
is_url=False, is_txt=False,
|
546 |
+
enable_captions=True,
|
547 |
+
captions_model=None,
|
548 |
+
enable_ocr=False, caption_loader=None):
|
549 |
+
if verbose:
|
550 |
+
if is_url:
|
551 |
+
print("Ingesting URL: %s" % file, flush=True)
|
552 |
+
elif is_txt:
|
553 |
+
print("Ingesting Text: %s" % file, flush=True)
|
554 |
+
else:
|
555 |
+
print("Ingesting file: %s" % file, flush=True)
|
556 |
+
res = None
|
557 |
+
try:
|
558 |
+
# don't pass base_path=path, would infinitely recurse
|
559 |
+
res = file_to_doc(file, base_path=None, verbose=verbose, fail_any_exception=fail_any_exception,
|
560 |
+
chunk=chunk, chunk_size=chunk_size,
|
561 |
+
is_url=is_url, is_txt=is_txt,
|
562 |
+
enable_captions=enable_captions,
|
563 |
+
captions_model=captions_model,
|
564 |
+
enable_ocr=enable_ocr,
|
565 |
+
caption_loader=caption_loader)
|
566 |
+
except BaseException as e:
|
567 |
+
print("Failed to ingest %s due to %s" % (file, traceback.format_exc()))
|
568 |
+
if fail_any_exception:
|
569 |
+
raise
|
570 |
+
else:
|
571 |
+
exception_doc = Document(
|
572 |
+
page_content='',
|
573 |
+
metadata={"source": file, "exception": str(e), "traceback": traceback.format_exc()})
|
574 |
+
res = [exception_doc]
|
575 |
+
if return_file:
|
576 |
+
base_tmp = "temp_path_to_doc1"
|
577 |
+
if not os.path.isdir(base_tmp):
|
578 |
+
os.makedirs(base_tmp, exist_ok=True)
|
579 |
+
filename = os.path.join(base_tmp, str(uuid.uuid4()) + ".tmp.pickle")
|
580 |
+
with open(filename, 'wb') as f:
|
581 |
+
pickle.dump(res, f)
|
582 |
+
return filename
|
583 |
+
return res
|
584 |
+
|
585 |
+
|
586 |
+
def path_to_docs(path_or_paths, verbose=False, fail_any_exception=False, n_jobs=-1,
|
587 |
+
chunk=True, chunk_size=512,
|
588 |
+
url=None, text=None,
|
589 |
+
enable_captions=True,
|
590 |
+
captions_model=None,
|
591 |
+
caption_loader=None,
|
592 |
+
enable_ocr=False,
|
593 |
+
):
|
594 |
+
globs_image_types = []
|
595 |
+
globs_non_image_types = []
|
596 |
+
if path_or_paths is None:
|
597 |
+
return []
|
598 |
+
elif url:
|
599 |
+
globs_non_image_types = [url]
|
600 |
+
elif text:
|
601 |
+
globs_non_image_types = [text]
|
602 |
+
elif isinstance(path_or_paths, str):
|
603 |
+
# single path, only consume allowed files
|
604 |
+
path = path_or_paths
|
605 |
+
# Below globs should match patterns in file_to_doc()
|
606 |
+
[globs_image_types.extend(glob.glob(os.path.join(path, "./**/*.%s" % ftype), recursive=True))
|
607 |
+
for ftype in image_types]
|
608 |
+
[globs_non_image_types.extend(glob.glob(os.path.join(path, "./**/*.%s" % ftype), recursive=True))
|
609 |
+
for ftype in non_image_types]
|
610 |
+
else:
|
611 |
+
# list/tuple of files (consume what can, and exception those that selected but cannot consume so user knows)
|
612 |
+
assert isinstance(path_or_paths, (list, tuple)), "Wrong type for path_or_paths: %s" % type(path_or_paths)
|
613 |
+
# reform out of allowed types
|
614 |
+
globs_image_types.extend(flatten_list([[x for x in path_or_paths if x.endswith(y)] for y in image_types]))
|
615 |
+
# could do below:
|
616 |
+
# globs_non_image_types = flatten_list([[x for x in path_or_paths if x.endswith(y)] for y in non_image_types])
|
617 |
+
# But instead, allow fail so can collect unsupported too
|
618 |
+
set_globs_image_types = set(globs_image_types)
|
619 |
+
globs_non_image_types.extend([x for x in path_or_paths if x not in set_globs_image_types])
|
620 |
+
# could use generator, but messes up metadata handling in recursive case
|
621 |
+
if caption_loader and not isinstance(caption_loader, (bool, str)) and \
|
622 |
+
caption_loader.device != 'cpu' or \
|
623 |
+
get_device() == 'cuda':
|
624 |
+
# to avoid deadlocks, presume was preloaded and so can't fork due to cuda context
|
625 |
+
n_jobs_image = 1
|
626 |
+
else:
|
627 |
+
n_jobs_image = n_jobs
|
628 |
+
|
629 |
+
return_file = True # local choice
|
630 |
+
is_url = url is not None
|
631 |
+
is_txt = text is not None
|
632 |
+
kwargs = dict(verbose=verbose, fail_any_exception=fail_any_exception,
|
633 |
+
return_file=return_file,
|
634 |
+
chunk=chunk, chunk_size=chunk_size,
|
635 |
+
is_url=is_url,
|
636 |
+
is_txt=is_txt,
|
637 |
+
enable_captions=enable_captions,
|
638 |
+
captions_model=captions_model,
|
639 |
+
caption_loader=caption_loader,
|
640 |
+
enable_ocr=enable_ocr,
|
641 |
+
)
|
642 |
+
|
643 |
+
if n_jobs != 1 and len(globs_non_image_types) > 1:
|
644 |
+
# avoid nesting, e.g. upload 1 zip and then inside many files
|
645 |
+
# harder to handle if upload many zips with many files, inner parallel one will be disabled by joblib
|
646 |
+
documents = Parallel(n_jobs=n_jobs, verbose=10 if verbose else 0, backend='multiprocessing')(
|
647 |
+
delayed(path_to_doc1)(file, **kwargs) for file in globs_non_image_types
|
648 |
+
)
|
649 |
+
else:
|
650 |
+
documents = [path_to_doc1(file, **kwargs) for file in globs_non_image_types]
|
651 |
+
|
652 |
+
# do images separately since can't fork after cuda in parent, so can't be parallel
|
653 |
+
if n_jobs_image != 1 and len(globs_image_types) > 1:
|
654 |
+
# avoid nesting, e.g. upload 1 zip and then inside many files
|
655 |
+
# harder to handle if upload many zips with many files, inner parallel one will be disabled by joblib
|
656 |
+
image_documents = Parallel(n_jobs=n_jobs, verbose=10 if verbose else 0, backend='multiprocessing')(
|
657 |
+
delayed(path_to_doc1)(file, **kwargs) for file in globs_image_types
|
658 |
+
)
|
659 |
+
else:
|
660 |
+
image_documents = [path_to_doc1(file, **kwargs) for file in globs_image_types]
|
661 |
+
|
662 |
+
# add image docs in
|
663 |
+
documents += image_documents
|
664 |
+
|
665 |
+
if return_file:
|
666 |
+
# then documents really are files
|
667 |
+
files = documents.copy()
|
668 |
+
documents = []
|
669 |
+
for fil in files:
|
670 |
+
with open(fil, 'rb') as f:
|
671 |
+
documents.extend(pickle.load(f))
|
672 |
+
# remove temp pickle
|
673 |
+
os.remove(fil)
|
674 |
+
else:
|
675 |
+
documents = reduce(concat, documents)
|
676 |
+
return documents
|
677 |
+
|
678 |
+
|
679 |
+
def prep_langchain(persist_directory, load_db_if_exists, db_type, use_openai_embedding, langchain_mode, user_path,
|
680 |
+
hf_embedding_model, n_jobs=-1, kwargs_make_db={}):
|
681 |
+
"""
|
682 |
+
do prep first time, involving downloads
|
683 |
+
# FIXME: Add github caching then add here
|
684 |
+
:return:
|
685 |
+
"""
|
686 |
+
assert langchain_mode not in ['MyData'], "Should not prep scratch data"
|
687 |
+
|
688 |
+
if os.path.isdir(persist_directory):
|
689 |
+
print("Prep: persist_directory=%s exists, using" % persist_directory, flush=True)
|
690 |
+
db = get_existing_db(persist_directory, load_db_if_exists, db_type, use_openai_embedding, langchain_mode,
|
691 |
+
hf_embedding_model)
|
692 |
+
else:
|
693 |
+
print("Prep: persist_directory=%s does not exist, regenerating" % persist_directory, flush=True)
|
694 |
+
db = None
|
695 |
+
if langchain_mode in ['All', 'DriverlessAI docs']:
|
696 |
+
# FIXME: Could also just use dai_docs.pickle directly and upload that
|
697 |
+
get_dai_docs(from_hf=True)
|
698 |
+
|
699 |
+
if langchain_mode in ['All', 'wiki']:
|
700 |
+
get_wiki_sources(first_para=kwargs_make_db['first_para'], text_limit=kwargs_make_db['text_limit'])
|
701 |
+
|
702 |
+
langchain_kwargs = kwargs_make_db.copy()
|
703 |
+
langchain_kwargs.update(locals())
|
704 |
+
db = make_db(**langchain_kwargs)
|
705 |
+
|
706 |
+
return db
|
707 |
+
|
708 |
+
|
709 |
+
def get_existing_db(persist_directory, load_db_if_exists, db_type, use_openai_embedding, langchain_mode,
|
710 |
+
hf_embedding_model):
|
711 |
+
if load_db_if_exists and db_type == 'chroma' and os.path.isdir(persist_directory) and os.path.isdir(
|
712 |
+
os.path.join(persist_directory, 'index')):
|
713 |
+
print("DO Loading db: %s" % langchain_mode, flush=True)
|
714 |
+
embedding = get_embedding(use_openai_embedding, hf_embedding_model=hf_embedding_model)
|
715 |
+
db = Chroma(persist_directory=persist_directory, embedding_function=embedding,
|
716 |
+
collection_name=langchain_mode.replace(' ', '_'))
|
717 |
+
print("DONE Loading db: %s" % langchain_mode, flush=True)
|
718 |
+
return db
|
719 |
+
return None
|
720 |
+
|
721 |
+
|
722 |
+
def make_db(**langchain_kwargs):
|
723 |
+
func_names = list(inspect.signature(_make_db).parameters)
|
724 |
+
missing_kwargs = [x for x in func_names if x not in langchain_kwargs]
|
725 |
+
defaults_db = {k: v.default for k, v in dict(inspect.signature(run_qa_db).parameters).items()}
|
726 |
+
for k in missing_kwargs:
|
727 |
+
if k in defaults_db:
|
728 |
+
langchain_kwargs[k] = defaults_db[k]
|
729 |
+
# final check for missing
|
730 |
+
missing_kwargs = [x for x in func_names if x not in langchain_kwargs]
|
731 |
+
assert not missing_kwargs, "Missing kwargs: %s" % missing_kwargs
|
732 |
+
# only keep actual used
|
733 |
+
langchain_kwargs = {k: v for k, v in langchain_kwargs.items() if k in func_names}
|
734 |
+
return _make_db(**langchain_kwargs)
|
735 |
+
|
736 |
+
|
737 |
+
def _make_db(use_openai_embedding=False,
|
738 |
+
hf_embedding_model="sentence-transformers/all-MiniLM-L6-v2",
|
739 |
+
first_para=False, text_limit=None, chunk=False, chunk_size=1024,
|
740 |
+
langchain_mode=None,
|
741 |
+
user_path=None,
|
742 |
+
db_type='faiss',
|
743 |
+
load_db_if_exists=False,
|
744 |
+
db=None,
|
745 |
+
n_jobs=-1):
|
746 |
+
persist_directory = 'db_dir_%s' % langchain_mode # single place, no special names for each case
|
747 |
+
if not db and load_db_if_exists and db_type == 'chroma' and os.path.isdir(persist_directory) and os.path.isdir(
|
748 |
+
os.path.join(persist_directory, 'index')):
|
749 |
+
assert langchain_mode not in ['MyData'], "Should not load MyData db this way"
|
750 |
+
print("Loading db", flush=True)
|
751 |
+
embedding = get_embedding(use_openai_embedding, hf_embedding_model=hf_embedding_model)
|
752 |
+
db = Chroma(persist_directory=persist_directory, embedding_function=embedding,
|
753 |
+
collection_name=langchain_mode.replace(' ', '_'))
|
754 |
+
elif not db:
|
755 |
+
assert langchain_mode not in ['MyData'], "Should not make MyData db this way"
|
756 |
+
sources = []
|
757 |
+
print("Generating sources", flush=True)
|
758 |
+
if langchain_mode in ['wiki_full', 'All', "'All'"]:
|
759 |
+
from read_wiki_full import get_all_documents
|
760 |
+
small_test = None
|
761 |
+
print("Generating new wiki", flush=True)
|
762 |
+
sources1 = get_all_documents(small_test=small_test, n_jobs=os.cpu_count() // 2)
|
763 |
+
print("Got new wiki", flush=True)
|
764 |
+
if chunk:
|
765 |
+
sources1 = chunk_sources(sources1, chunk_size=chunk_size)
|
766 |
+
print("Chunked new wiki", flush=True)
|
767 |
+
sources.extend(sources1)
|
768 |
+
if langchain_mode in ['wiki', 'All', "'All'"]:
|
769 |
+
sources1 = get_wiki_sources(first_para=first_para, text_limit=text_limit)
|
770 |
+
if chunk:
|
771 |
+
sources1 = chunk_sources(sources1, chunk_size=chunk_size)
|
772 |
+
sources.extend(sources1)
|
773 |
+
if langchain_mode in ['github h2oGPT', 'All', "'All'"]:
|
774 |
+
# sources = get_github_docs("dagster-io", "dagster")
|
775 |
+
sources1 = get_github_docs("h2oai", "h2ogpt")
|
776 |
+
# FIXME: always chunk for now
|
777 |
+
sources1 = chunk_sources(sources1, chunk_size=chunk_size)
|
778 |
+
sources.extend(sources1)
|
779 |
+
if langchain_mode in ['DriverlessAI docs', 'All', "'All'"]:
|
780 |
+
sources1 = get_dai_docs(from_hf=True)
|
781 |
+
if chunk and False: # FIXME: DAI docs are already chunked well, should only chunk more if over limit
|
782 |
+
sources1 = chunk_sources(sources1, chunk_size=chunk_size)
|
783 |
+
sources.extend(sources1)
|
784 |
+
if langchain_mode in ['All', 'UserData']:
|
785 |
+
if user_path:
|
786 |
+
# chunk internally for speed over multiple docs
|
787 |
+
sources1 = path_to_docs(user_path, n_jobs=n_jobs, chunk=chunk, chunk_size=chunk_size)
|
788 |
+
sources.extend(sources1)
|
789 |
+
else:
|
790 |
+
print("Chose UserData but user_path is empty/None", flush=True)
|
791 |
+
if False and langchain_mode in ['urls', 'All', "'All'"]:
|
792 |
+
# from langchain.document_loaders import UnstructuredURLLoader
|
793 |
+
# loader = UnstructuredURLLoader(urls=urls)
|
794 |
+
urls = ["https://www.birdsongsf.com/who-we-are/"]
|
795 |
+
from langchain.document_loaders import PlaywrightURLLoader
|
796 |
+
loader = PlaywrightURLLoader(urls=urls, remove_selectors=["header", "footer"])
|
797 |
+
sources1 = loader.load()
|
798 |
+
sources.extend(sources1)
|
799 |
+
if not sources:
|
800 |
+
print("langchain_mode %s has no sources, not making db" % langchain_mode, flush=True)
|
801 |
+
return None
|
802 |
+
print("Generating db", flush=True)
|
803 |
+
db = get_db(sources, use_openai_embedding=use_openai_embedding, db_type=db_type,
|
804 |
+
persist_directory=persist_directory, langchain_mode=langchain_mode,
|
805 |
+
hf_embedding_model=hf_embedding_model)
|
806 |
+
print("Generated db", flush=True)
|
807 |
+
return db
|
808 |
+
|
809 |
+
|
810 |
+
source_prefix = "Sources [Score | Link]:"
|
811 |
+
source_postfix = "End Sources<p>"
|
812 |
+
|
813 |
+
|
814 |
+
def run_qa_db(**kwargs):
|
815 |
+
func_names = list(inspect.signature(_run_qa_db).parameters)
|
816 |
+
# hard-coded defaults
|
817 |
+
kwargs['answer_with_sources'] = True
|
818 |
+
kwargs['sanitize_bot_response'] = True
|
819 |
+
kwargs['show_rank'] = False
|
820 |
+
missing_kwargs = [x for x in func_names if x not in kwargs]
|
821 |
+
assert not missing_kwargs, "Missing kwargs: %s" % missing_kwargs
|
822 |
+
# only keep actual used
|
823 |
+
kwargs = {k: v for k, v in kwargs.items() if k in func_names}
|
824 |
+
return _run_qa_db(**kwargs)
|
825 |
+
|
826 |
+
|
827 |
+
def _run_qa_db(query=None,
|
828 |
+
use_openai_model=False, use_openai_embedding=False,
|
829 |
+
first_para=False, text_limit=None, k=4, chunk=False, chunk_size=1024,
|
830 |
+
user_path=None,
|
831 |
+
db_type='faiss',
|
832 |
+
model_name=None, model=None, tokenizer=None,
|
833 |
+
hf_embedding_model="sentence-transformers/all-MiniLM-L6-v2",
|
834 |
+
stream_output=False,
|
835 |
+
prompter=None,
|
836 |
+
prompt_type=None,
|
837 |
+
answer_with_sources=True,
|
838 |
+
cut_distanct=1.1,
|
839 |
+
sanitize_bot_response=True,
|
840 |
+
show_rank=False,
|
841 |
+
load_db_if_exists=False,
|
842 |
+
db=None,
|
843 |
+
max_new_tokens=256,
|
844 |
+
temperature=0.1,
|
845 |
+
repetition_penalty=1.0,
|
846 |
+
top_k=40,
|
847 |
+
top_p=0.7,
|
848 |
+
langchain_mode=None,
|
849 |
+
n_jobs=-1):
|
850 |
+
"""
|
851 |
+
|
852 |
+
:param query:
|
853 |
+
:param use_openai_model:
|
854 |
+
:param use_openai_embedding:
|
855 |
+
:param first_para:
|
856 |
+
:param text_limit:
|
857 |
+
:param k:
|
858 |
+
:param chunk:
|
859 |
+
:param chunk_size:
|
860 |
+
:param user_path: user path to glob recursively from
|
861 |
+
:param db_type: 'faiss' for in-memory db or 'chroma' for persistent db
|
862 |
+
:param model_name: model name, used to switch behaviors
|
863 |
+
:param model: pre-initialized model, else will make new one
|
864 |
+
:param tokenizer: pre-initialized tokenizer, else will make new one. Required not None if model is not None
|
865 |
+
:param answer_with_sources
|
866 |
+
:return:
|
867 |
+
"""
|
868 |
+
|
869 |
+
# FIXME: For All just go over all dbs instead of a separate db for All
|
870 |
+
db = make_db(**locals())
|
871 |
+
prompt_type = prompter.prompt_type if prompter is not None else prompt_type
|
872 |
+
llm, model_name, streamer, prompt_type_out = get_llm(use_openai_model=use_openai_model, model_name=model_name,
|
873 |
+
model=model, tokenizer=tokenizer,
|
874 |
+
stream_output=stream_output,
|
875 |
+
max_new_tokens=max_new_tokens,
|
876 |
+
temperature=temperature,
|
877 |
+
repetition_penalty=repetition_penalty,
|
878 |
+
top_k=top_k,
|
879 |
+
top_p=top_p,
|
880 |
+
prompt_type=prompt_type,
|
881 |
+
)
|
882 |
+
|
883 |
+
if model_name in ['llama', 'gptj']:
|
884 |
+
# FIXME: for now, streams to stdout/stderr currently
|
885 |
+
stream_output = False
|
886 |
+
|
887 |
+
if not use_openai_model and prompt_type not in ['plain'] or model_name in ['llama', 'gptj']:
|
888 |
+
# instruct-like, rather than few-shot prompt_type='plain' as default
|
889 |
+
# but then sources confuse the model with how inserted among rest of text, so avoid
|
890 |
+
prefix = ""
|
891 |
+
if langchain_mode in ['Disabled', 'ChatLLM', 'LLM']:
|
892 |
+
use_context = False
|
893 |
+
template = """%s{context}{question}""" % prefix
|
894 |
+
else:
|
895 |
+
use_context = True
|
896 |
+
template = """%s
|
897 |
+
==
|
898 |
+
{context}
|
899 |
+
==
|
900 |
+
{question}""" % prefix
|
901 |
+
prompt = PromptTemplate(
|
902 |
+
# input_variables=["summaries", "question"],
|
903 |
+
input_variables=["context", "question"],
|
904 |
+
template=template,
|
905 |
+
)
|
906 |
+
chain = load_qa_chain(llm, prompt=prompt)
|
907 |
+
else:
|
908 |
+
chain = load_qa_with_sources_chain(llm)
|
909 |
+
use_context = True
|
910 |
+
|
911 |
+
if query is None:
|
912 |
+
query = "What are the main differences between Linux and Windows?"
|
913 |
+
# https://github.com/hwchase17/langchain/issues/1946
|
914 |
+
# FIXME: Seems to way to get size of chroma db to limit k to avoid
|
915 |
+
# Chroma collection MyData contains fewer than 4 elements.
|
916 |
+
# type logger error
|
917 |
+
k_db = 1000 if db_type == 'chroma' else k # k=100 works ok too for
|
918 |
+
|
919 |
+
if db and use_context:
|
920 |
+
docs_with_score = db.similarity_search_with_score(query, k=k_db)[:k]
|
921 |
+
# cut off so no high distance docs/sources considered
|
922 |
+
docs = [x[0] for x in docs_with_score if x[1] < cut_distanct]
|
923 |
+
scores = [x[1] for x in docs_with_score if x[1] < cut_distanct]
|
924 |
+
if len(scores) > 0:
|
925 |
+
print("Distance: min: %s max: %s mean: %s median: %s" %
|
926 |
+
(scores[0], scores[-1], np.mean(scores), np.median(scores)), flush=True)
|
927 |
+
else:
|
928 |
+
docs = []
|
929 |
+
scores = []
|
930 |
+
|
931 |
+
if not docs and use_context:
|
932 |
+
return None
|
933 |
+
|
934 |
+
common_words_file = "data/NGSL_1.2_stats.csv.zip"
|
935 |
+
if os.path.isfile(common_words_file):
|
936 |
+
df = pd.read_csv("data/NGSL_1.2_stats.csv.zip")
|
937 |
+
import string
|
938 |
+
reduced_query = query.translate(str.maketrans(string.punctuation, ' ' * len(string.punctuation))).strip()
|
939 |
+
reduced_query_words = reduced_query.split(' ')
|
940 |
+
set_common = set(df['Lemma'].values.tolist())
|
941 |
+
num_common = len([x.lower() in set_common for x in reduced_query_words])
|
942 |
+
frac_common = num_common / len(reduced_query)
|
943 |
+
# FIXME: report to user bad query that uses too many common words
|
944 |
+
print("frac_common: %s" % frac_common, flush=True)
|
945 |
+
|
946 |
+
if langchain_mode in ['Disabled', 'ChatLLM', 'LLM']:
|
947 |
+
chain_kwargs = dict(input_documents=[], question=query)
|
948 |
+
else:
|
949 |
+
chain_kwargs = dict(input_documents=docs, question=query)
|
950 |
+
|
951 |
+
if stream_output:
|
952 |
+
answer = None
|
953 |
+
assert streamer is not None
|
954 |
+
target = wrapped_partial(chain, chain_kwargs)
|
955 |
+
import queue
|
956 |
+
bucket = queue.Queue()
|
957 |
+
thread = EThread(target=target, streamer=streamer, bucket=bucket)
|
958 |
+
thread.start()
|
959 |
+
outputs = ""
|
960 |
+
prompt = None # FIXME
|
961 |
+
try:
|
962 |
+
for new_text in streamer:
|
963 |
+
# print("new_text: %s" % new_text, flush=True)
|
964 |
+
if bucket.qsize() > 0 or thread.exc:
|
965 |
+
thread.join()
|
966 |
+
outputs += new_text
|
967 |
+
if prompter: # and False: # FIXME: pipeline can already use prompter
|
968 |
+
output1 = prompter.get_response(outputs, prompt=prompt,
|
969 |
+
sanitize_bot_response=sanitize_bot_response)
|
970 |
+
yield output1
|
971 |
+
else:
|
972 |
+
yield outputs
|
973 |
+
except BaseException:
|
974 |
+
# if any exception, raise that exception if was from thread, first
|
975 |
+
if thread.exc:
|
976 |
+
raise thread.exc
|
977 |
+
raise
|
978 |
+
finally:
|
979 |
+
# in case no exception and didn't join with thread yet, then join
|
980 |
+
if not thread.exc:
|
981 |
+
answer = thread.join()
|
982 |
+
# in case raise StopIteration or broke queue loop in streamer, but still have exception
|
983 |
+
if thread.exc:
|
984 |
+
raise thread.exc
|
985 |
+
# FIXME: answer is not string outputs from streamer. How to get actual final output?
|
986 |
+
# answer = outputs
|
987 |
+
else:
|
988 |
+
answer = chain(chain_kwargs)
|
989 |
+
|
990 |
+
if not use_context:
|
991 |
+
ret = answer['output_text']
|
992 |
+
yield ret
|
993 |
+
elif answer is not None:
|
994 |
+
print("query: %s" % query, flush=True)
|
995 |
+
print("answer: %s" % answer['output_text'], flush=True)
|
996 |
+
# link
|
997 |
+
answer_sources = [(max(0.0, 1.5 - score) / 1.5, get_url(doc)) for score, doc in
|
998 |
+
zip(scores, answer['input_documents'])]
|
999 |
+
answer_sources_dict = defaultdict(list)
|
1000 |
+
[answer_sources_dict[url].append(score) for score, url in answer_sources]
|
1001 |
+
answers_dict = {}
|
1002 |
+
for url, scores_url in answer_sources_dict.items():
|
1003 |
+
answers_dict[url] = np.max(scores_url)
|
1004 |
+
answer_sources = [(score, url) for url, score in answers_dict.items()]
|
1005 |
+
answer_sources.sort(key=lambda x: x[0], reverse=True)
|
1006 |
+
if show_rank:
|
1007 |
+
# answer_sources = ['%d | %s' % (1 + rank, url) for rank, (score, url) in enumerate(answer_sources)]
|
1008 |
+
# sorted_sources_urls = "Sources [Rank | Link]:<br>" + "<br>".join(answer_sources)
|
1009 |
+
answer_sources = ['%s' % url for rank, (score, url) in enumerate(answer_sources)]
|
1010 |
+
sorted_sources_urls = "Ranked Sources:<br>" + "<br>".join(answer_sources)
|
1011 |
+
else:
|
1012 |
+
answer_sources = ['<li>%.2g | %s</li>' % (score, url) for score, url in answer_sources]
|
1013 |
+
sorted_sources_urls = f"{source_prefix}<p><ul>" + "<p>".join(answer_sources)
|
1014 |
+
sorted_sources_urls += f"</ul></p>{source_postfix}"
|
1015 |
+
|
1016 |
+
if not answer['output_text'].endswith('\n'):
|
1017 |
+
answer['output_text'] += '\n'
|
1018 |
+
|
1019 |
+
if answer_with_sources:
|
1020 |
+
ret = answer['output_text'] + '\n' + sorted_sources_urls
|
1021 |
+
else:
|
1022 |
+
ret = answer['output_text']
|
1023 |
+
|
1024 |
+
yield ret
|
1025 |
+
return
|
1026 |
+
|
1027 |
+
|
1028 |
+
def chunk_sources(sources, chunk_size=1024):
|
1029 |
+
source_chunks = []
|
1030 |
+
# Below for known separator
|
1031 |
+
# splitter = CharacterTextSplitter(separator=" ", chunk_size=chunk_size, chunk_overlap=0)
|
1032 |
+
splitter = RecursiveCharacterTextSplitter(chunk_size=chunk_size, chunk_overlap=0)
|
1033 |
+
for source in sources:
|
1034 |
+
# print(source.metadata['source'], flush=True)
|
1035 |
+
for chunky in splitter.split_text(source.page_content):
|
1036 |
+
source_chunks.append(Document(page_content=chunky, metadata=source.metadata))
|
1037 |
+
return source_chunks
|
1038 |
+
|
1039 |
+
|
1040 |
+
def get_db_from_hf(dest=".", db_dir='db_dir_DriverlessAI_docs.zip'):
|
1041 |
+
from huggingface_hub import hf_hub_download
|
1042 |
+
# True for case when locally already logged in with correct token, so don't have to set key
|
1043 |
+
token = os.getenv('HUGGINGFACE_API_TOKEN', True)
|
1044 |
+
path_to_zip_file = hf_hub_download('h2oai/db_dirs', db_dir, token=token, repo_type='dataset')
|
1045 |
+
import zipfile
|
1046 |
+
with zipfile.ZipFile(path_to_zip_file, 'r') as zip_ref:
|
1047 |
+
zip_ref.extractall(dest)
|
1048 |
+
return path_to_zip_file
|
1049 |
+
|
1050 |
+
|
1051 |
+
# Note dir has space in some cases, while zip does not
|
1052 |
+
some_db_zips = [['db_dir_DriverlessAI_docs.zip', 'db_dir_DriverlessAI docs', 'CC-BY-NC license'],
|
1053 |
+
['db_dir_UserData.zip', 'db_dir_UserData', 'CC-BY license for ArXiv'],
|
1054 |
+
['db_dir_github_h2oGPT.zip', 'db_dir_github h2oGPT', 'ApacheV2 license'],
|
1055 |
+
['db_dir_wiki.zip', 'db_dir_wiki', 'CC-BY-SA Wikipedia license'],
|
1056 |
+
# ['db_dir_wiki_full.zip', 'db_dir_wiki_full.zip', '23GB, 05/04/2023 CC-BY-SA Wiki license'],
|
1057 |
+
]
|
1058 |
+
|
1059 |
+
all_db_zips = some_db_zips + \
|
1060 |
+
[['db_dir_wiki_full.zip', 'db_dir_wiki_full.zip', '23GB, 05/04/2023 CC-BY-SA Wiki license'],
|
1061 |
+
]
|
1062 |
+
|
1063 |
+
|
1064 |
+
def get_some_dbs_from_hf(dest='.', db_zips=None):
|
1065 |
+
if db_zips is None:
|
1066 |
+
db_zips = some_db_zips
|
1067 |
+
for db_dir, dir_expected, license1 in db_zips:
|
1068 |
+
path_to_zip_file = get_db_from_hf(dest=dest, db_dir=db_dir)
|
1069 |
+
assert os.path.isfile(path_to_zip_file), "Missing zip in %s" % path_to_zip_file
|
1070 |
+
if dir_expected:
|
1071 |
+
assert os.path.isdir(os.path.join(dest, dir_expected)), "Missing path for %s" % dir_expected
|
1072 |
+
assert os.path.isdir(os.path.join(dest, dir_expected, 'index')), "Missing index in %s" % dir_expected
|
1073 |
+
|
1074 |
+
|
1075 |
+
if __name__ == '__main__':
|
1076 |
+
pass
|
gradio_runner.py
CHANGED
@@ -1,15 +1,23 @@
|
|
1 |
import copy
|
2 |
import functools
|
3 |
import inspect
|
|
|
4 |
import os
|
|
|
5 |
import sys
|
|
|
|
|
|
|
|
|
|
|
6 |
|
7 |
from gradio_themes import H2oTheme, SoftTheme, get_h2o_title, get_simple_title, get_dark_js
|
8 |
-
from prompter import Prompter
|
|
|
9 |
from utils import get_githash, flatten_list, zip_data, s3up, clear_torch_cache, get_torch_allocated, system_info_print, \
|
10 |
-
ping
|
11 |
-
from
|
12 |
-
|
13 |
|
14 |
import gradio as gr
|
15 |
from apscheduler.schedulers.background import BackgroundScheduler
|
@@ -25,6 +33,21 @@ def go_gradio(**kwargs):
|
|
25 |
model_state0 = kwargs['model_state0']
|
26 |
score_model_state0 = kwargs['score_model_state0']
|
27 |
queue = True
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
28 |
|
29 |
# easy update of kwargs needed for evaluate() etc.
|
30 |
kwargs.update(locals())
|
@@ -42,6 +65,9 @@ def go_gradio(**kwargs):
|
|
42 |
title = 'h2oGPT'
|
43 |
if 'h2ogpt-research' in kwargs['base_model']:
|
44 |
title += " [Research demonstration]"
|
|
|
|
|
|
|
45 |
if kwargs['verbose']:
|
46 |
description = f"""Model {kwargs['base_model']} Instruct dataset.
|
47 |
For more information, visit our GitHub pages: [h2oGPT](https://github.com/h2oai/h2ogpt) and [H2O LLM Studio](https://github.com/h2oai/h2o-llmstudio).
|
@@ -49,9 +75,11 @@ def go_gradio(**kwargs):
|
|
49 |
Hash: {get_githash()}
|
50 |
"""
|
51 |
else:
|
52 |
-
description =
|
53 |
description += "If this host is busy, try [12B](https://gpt.h2o.ai), [30B](http://gpt2.h2o.ai), [HF Spaces1 12B](https://huggingface.co/spaces/h2oai/h2ogpt-chatbot) or [HF Spaces2 12B](https://huggingface.co/spaces/h2oai/h2ogpt-chatbot2)<br>"
|
54 |
description += """<p>By using h2oGPT, you accept our [Terms of Service](https://github.com/h2oai/h2ogpt/blob/main/tos.md)</p>"""
|
|
|
|
|
55 |
|
56 |
if kwargs['verbose']:
|
57 |
task_info_md = f"""
|
@@ -66,6 +94,9 @@ def go_gradio(**kwargs):
|
|
66 |
"""
|
67 |
else:
|
68 |
css_code = """footer {visibility: hidden}"""
|
|
|
|
|
|
|
69 |
|
70 |
if kwargs['gradio_avoid_processing_markdown']:
|
71 |
from gradio_client import utils as client_utils
|
@@ -134,6 +165,8 @@ def go_gradio(**kwargs):
|
|
134 |
model_state2 = gr.State([None, None, None, None])
|
135 |
model_options_state = gr.State([model_options])
|
136 |
lora_options_state = gr.State([lora_options])
|
|
|
|
|
137 |
gr.Markdown(f"""
|
138 |
{get_h2o_title(title) if kwargs['h2ocolors'] else get_simple_title(title)}
|
139 |
|
@@ -142,7 +175,7 @@ def go_gradio(**kwargs):
|
|
142 |
""")
|
143 |
if is_hf:
|
144 |
gr.HTML(
|
145 |
-
|
146 |
|
147 |
# go button visible if
|
148 |
base_wanted = kwargs['base_model'] != no_model_str and kwargs['login_mode_if_model0']
|
@@ -153,7 +186,7 @@ def go_gradio(**kwargs):
|
|
153 |
with gr.Row():
|
154 |
col_nochat = gr.Column(visible=not kwargs['chat'])
|
155 |
with col_nochat: # FIXME: for model comparison, and check rest
|
156 |
-
text_output_nochat = gr.Textbox(lines=5, label=output_label0)
|
157 |
instruction_nochat = gr.Textbox(
|
158 |
lines=kwargs['input_lines'],
|
159 |
label=instruction_label_nochat,
|
@@ -187,7 +220,7 @@ def go_gradio(**kwargs):
|
|
187 |
submit = gr.Button(value='Submit').style(full_width=False, size='sm')
|
188 |
stop_btn = gr.Button(value="Stop").style(full_width=False, size='sm')
|
189 |
with gr.Row():
|
190 |
-
clear = gr.Button("New Conversation")
|
191 |
flag_btn = gr.Button("Flag")
|
192 |
if not kwargs['auto_score']: # FIXME: For checkbox model2
|
193 |
with gr.Column(visible=kwargs['score_model']):
|
@@ -206,7 +239,7 @@ def go_gradio(**kwargs):
|
|
206 |
score_text2 = gr.Textbox("Response Score2: NA", show_label=False, visible=False)
|
207 |
retry = gr.Button("Regenerate")
|
208 |
undo = gr.Button("Undo")
|
209 |
-
with gr.TabItem("
|
210 |
with gr.Row():
|
211 |
if 'mbart-' in kwargs['model_lower']:
|
212 |
src_lang = gr.Dropdown(list(languages_covered().keys()),
|
@@ -215,6 +248,122 @@ def go_gradio(**kwargs):
|
|
215 |
tgt_lang = gr.Dropdown(list(languages_covered().keys()),
|
216 |
value=kwargs['tgt_lang'],
|
217 |
label="Output Language")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
218 |
with gr.TabItem("Expert"):
|
219 |
with gr.Row():
|
220 |
with gr.Column():
|
@@ -243,7 +392,7 @@ def go_gradio(**kwargs):
|
|
243 |
)
|
244 |
# FIXME: https://github.com/h2oai/h2ogpt/issues/106
|
245 |
if os.getenv('TESTINGFAIL'):
|
246 |
-
|
247 |
else:
|
248 |
max_beams = 1
|
249 |
num_beams = gr.Slider(minimum=1, maximum=max_beams, step=1,
|
@@ -356,12 +505,13 @@ def go_gradio(**kwargs):
|
|
356 |
with gr.Column():
|
357 |
with gr.Row():
|
358 |
system_btn = gr.Button(value='Get System Info')
|
359 |
-
system_text = gr.Textbox(label='System Info', interactive=False)
|
|
|
360 |
|
361 |
with gr.Row():
|
362 |
zip_btn = gr.Button("Zip")
|
363 |
zip_text = gr.Textbox(label="Zip file name", interactive=False)
|
364 |
-
file_output = gr.File()
|
365 |
with gr.Row():
|
366 |
s3up_btn = gr.Button("S3UP")
|
367 |
s3up_text = gr.Textbox(label='S3UP result', interactive=False)
|
@@ -378,8 +528,103 @@ def go_gradio(**kwargs):
|
|
378 |
|
379 |
# Get flagged data
|
380 |
zip_data1 = functools.partial(zip_data, root_dirs=['flagged_data_points', kwargs['save_dir']])
|
381 |
-
zip_btn.click(zip_data1, inputs=None, outputs=[file_output, zip_text], queue=False
|
382 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
383 |
|
384 |
def check_admin_pass(x):
|
385 |
return gr.update(visible=x == admin_pass)
|
@@ -569,49 +814,66 @@ def go_gradio(**kwargs):
|
|
569 |
"""
|
570 |
# don't deepcopy, can contain model itself
|
571 |
args_list = list(args).copy()
|
572 |
-
|
|
|
|
|
|
|
|
|
|
|
573 |
if retry and history:
|
574 |
history.pop()
|
|
|
|
|
|
|
575 |
if not history:
|
576 |
print("No history", flush=True)
|
577 |
history = [['', None]]
|
578 |
yield history, ''
|
579 |
return
|
580 |
# ensure output will be unique to models
|
|
|
581 |
history = copy.deepcopy(history)
|
582 |
instruction1 = history[-1][0]
|
583 |
context1 = ''
|
584 |
-
if
|
585 |
-
|
586 |
-
|
587 |
-
chat_arg_id = eval_func_param_names.index('chat')
|
588 |
-
chat1 = args_list[chat_arg_id]
|
589 |
context1 = ''
|
590 |
-
|
|
|
591 |
data_point = dict(instruction=history[histi][0], input='', output=history[histi][1])
|
592 |
prompt, pre_response, terminate_response, chat_sep = generate_prompt(data_point, prompt_type1,
|
593 |
chat1, reduced=True)
|
594 |
-
# md -> back to text, maybe not super
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
595 |
prompt = prompt.replace('<br>', chat_sep)
|
596 |
-
|
597 |
-
|
598 |
-
|
|
|
|
|
|
|
|
|
599 |
|
600 |
_, pre_response, terminate_response, chat_sep = generate_prompt({}, prompt_type1, chat1,
|
601 |
reduced=True)
|
602 |
if context1 and not context1.endswith(chat_sep):
|
603 |
context1 += chat_sep # ensure if terminates abruptly, then human continues on next line
|
604 |
args_list[0] = instruction1 # override original instruction with history from user
|
605 |
-
|
606 |
-
args_list[2] = context1[-kwargs['chat_history']:]
|
607 |
-
model_state1 = args_list[-2]
|
608 |
if model_state1[0] is None or model_state1[0] == no_model_str:
|
609 |
history = [['', None]]
|
610 |
yield history, ''
|
611 |
return
|
612 |
-
args_list = args_list[:-2]
|
613 |
fun1 = partial(evaluate,
|
614 |
model_state1,
|
|
|
615 |
**kwargs_evaluate)
|
616 |
try:
|
617 |
for output in fun1(*tuple(args_list)):
|
@@ -645,11 +907,11 @@ def go_gradio(**kwargs):
|
|
645 |
outputs=text_output,
|
646 |
)
|
647 |
bot_args = dict(fn=bot,
|
648 |
-
inputs=inputs_list + [model_state] + [text_output],
|
649 |
outputs=[text_output, exception_text],
|
650 |
)
|
651 |
retry_bot_args = dict(fn=functools.partial(bot, retry=True),
|
652 |
-
inputs=inputs_list + [model_state] + [text_output],
|
653 |
outputs=[text_output, exception_text],
|
654 |
)
|
655 |
undo_user_args = dict(fn=functools.partial(user, undo=True),
|
@@ -663,11 +925,11 @@ def go_gradio(**kwargs):
|
|
663 |
outputs=text_output2,
|
664 |
)
|
665 |
bot_args2 = dict(fn=bot,
|
666 |
-
inputs=inputs_list + [model_state2] + [text_output2],
|
667 |
outputs=[text_output2, exception_text],
|
668 |
)
|
669 |
retry_bot_args2 = dict(fn=functools.partial(bot, retry=True),
|
670 |
-
inputs=inputs_list + [model_state2] + [text_output2],
|
671 |
outputs=[text_output2, exception_text],
|
672 |
)
|
673 |
undo_user_args2 = dict(fn=functools.partial(user, undo=True),
|
@@ -694,7 +956,8 @@ def go_gradio(**kwargs):
|
|
694 |
.then(clear_instruct, None, iinput)
|
695 |
submit_event1d = submit_event1c.then(**bot_args, api_name='instruction_bot' if allow_api else None,
|
696 |
queue=queue)
|
697 |
-
submit_event1e = submit_event1d.then(**score_args_submit,
|
|
|
698 |
queue=queue)
|
699 |
submit_event1f = submit_event1e.then(**bot_args2, api_name='instruction_bot2' if allow_api else None,
|
700 |
queue=queue)
|
@@ -735,12 +998,134 @@ def go_gradio(**kwargs):
|
|
735 |
.then(**score_args_submit, api_name='undo_score' if allow_api else None) \
|
736 |
.then(**score_args2_submit, api_name='undo_score2' if allow_api else None)
|
737 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
738 |
# does both models
|
739 |
-
clear.click(
|
740 |
-
|
|
|
|
|
|
|
|
|
741 |
# NOTE: clear of instruction/iinput for nochat has to come after score,
|
742 |
# because score for nochat consumes actual textbox, while chat consumes chat history filled by user()
|
743 |
-
submit_event_nochat = submit_nochat.click(fun,
|
|
|
744 |
outputs=text_output_nochat,
|
745 |
queue=queue,
|
746 |
api_name='submit_nochat' if allow_api else None) \
|
@@ -842,8 +1227,8 @@ def go_gradio(**kwargs):
|
|
842 |
new_state = [list0[0] + [x]]
|
843 |
new_options = [*new_state[0]]
|
844 |
return gr.Dropdown.update(value=x, choices=new_options), \
|
845 |
-
|
846 |
-
|
847 |
|
848 |
add_model_event = add_model_button.click(fn=dropdown_model_list,
|
849 |
inputs=[model_options_state, new_model],
|
@@ -857,8 +1242,8 @@ def go_gradio(**kwargs):
|
|
857 |
x1 = x if model_used1 == no_model_str else lora_used1
|
858 |
x2 = x if model_used2 == no_model_str else lora_used2
|
859 |
return gr.Dropdown.update(value=x1, choices=new_options), \
|
860 |
-
|
861 |
-
|
862 |
|
863 |
add_lora_event = add_lora_button.click(fn=dropdown_lora_list,
|
864 |
inputs=[lora_options_state, new_lora, model_used, lora_used, model_used2,
|
@@ -916,10 +1301,20 @@ def go_gradio(**kwargs):
|
|
916 |
|
917 |
scheduler = BackgroundScheduler()
|
918 |
scheduler.add_job(func=clear_torch_cache, trigger="interval", seconds=20)
|
919 |
-
if is_public
|
|
|
|
|
|
|
920 |
scheduler.add_job(func=ping, trigger="interval", seconds=60)
|
921 |
scheduler.start()
|
922 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
923 |
demo.launch(share=kwargs['share'], server_name="0.0.0.0", show_error=True,
|
924 |
favicon_path=favicon_path, prevent_thread_lock=True,
|
925 |
auth=kwargs['auth'])
|
@@ -928,9 +1323,7 @@ def go_gradio(**kwargs):
|
|
928 |
demo.block_thread()
|
929 |
|
930 |
|
931 |
-
input_args_list = ['model_state']
|
932 |
-
inputs_kwargs_list = ['debug', 'save_dir', 'sanitize_bot_response', 'model_state0', 'is_low_mem',
|
933 |
-
'raise_generate_gpu_exceptions', 'chat_context', 'concurrency_count', 'lora_weights']
|
934 |
|
935 |
|
936 |
def get_inputs_list(inputs_dict, model_lower):
|
@@ -946,9 +1339,204 @@ def get_inputs_list(inputs_dict, model_lower):
|
|
946 |
if k == 'kwargs':
|
947 |
continue
|
948 |
if k in input_args_list + inputs_kwargs_list:
|
949 |
-
# these are added
|
950 |
continue
|
951 |
if 'mbart-' not in model_lower and k in ['src_lang', 'tgt_lang']:
|
952 |
continue
|
953 |
inputs_list.append(inputs_dict[k])
|
954 |
return inputs_list
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
import copy
|
2 |
import functools
|
3 |
import inspect
|
4 |
+
import json
|
5 |
import os
|
6 |
+
import random
|
7 |
import sys
|
8 |
+
import traceback
|
9 |
+
import uuid
|
10 |
+
import filelock
|
11 |
+
import pandas as pd
|
12 |
+
import tabulate
|
13 |
|
14 |
from gradio_themes import H2oTheme, SoftTheme, get_h2o_title, get_simple_title, get_dark_js
|
15 |
+
from prompter import Prompter, \
|
16 |
+
prompt_type_to_model_name, prompt_types_strings, inv_prompt_type_to_model_lower, generate_prompt
|
17 |
from utils import get_githash, flatten_list, zip_data, s3up, clear_torch_cache, get_torch_allocated, system_info_print, \
|
18 |
+
ping, get_short_name, get_url, makedirs
|
19 |
+
from generate import get_model, languages_covered, evaluate, eval_func_param_names, score_qa, langchain_modes, \
|
20 |
+
inputs_kwargs_list, get_cutoffs, scratch_base_dir
|
21 |
|
22 |
import gradio as gr
|
23 |
from apscheduler.schedulers.background import BackgroundScheduler
|
|
|
33 |
model_state0 = kwargs['model_state0']
|
34 |
score_model_state0 = kwargs['score_model_state0']
|
35 |
queue = True
|
36 |
+
dbs = kwargs['dbs']
|
37 |
+
db_type = kwargs['db_type']
|
38 |
+
visible_langchain_modes = kwargs['visible_langchain_modes']
|
39 |
+
allow_upload_to_user_data = kwargs['allow_upload_to_user_data']
|
40 |
+
allow_upload_to_my_data = kwargs['allow_upload_to_my_data']
|
41 |
+
enable_sources_list = kwargs['enable_sources_list']
|
42 |
+
enable_url_upload = kwargs['enable_url_upload']
|
43 |
+
enable_text_upload = kwargs['enable_text_upload']
|
44 |
+
allow_upload = allow_upload_to_user_data or allow_upload_to_my_data
|
45 |
+
use_openai_embedding = kwargs['use_openai_embedding']
|
46 |
+
hf_embedding_model = kwargs['hf_embedding_model']
|
47 |
+
enable_captions = kwargs['enable_captions']
|
48 |
+
captions_model = kwargs['captions_model']
|
49 |
+
enable_ocr = kwargs['enable_ocr']
|
50 |
+
caption_loader = kwargs['caption_loader']
|
51 |
|
52 |
# easy update of kwargs needed for evaluate() etc.
|
53 |
kwargs.update(locals())
|
|
|
65 |
title = 'h2oGPT'
|
66 |
if 'h2ogpt-research' in kwargs['base_model']:
|
67 |
title += " [Research demonstration]"
|
68 |
+
more_info = """For more information, visit our GitHub pages: [h2oGPT](https://github.com/h2oai/h2ogpt) and [H2O-LLMStudio](https://github.com/h2oai/h2o-llmstudio)<br>"""
|
69 |
+
if is_public:
|
70 |
+
more_info += """<iframe src="https://ghbtns.com/github-btn.html?user=h2oai&repo=h2ogpt&type=star&count=true&size=small" frameborder="0" scrolling="0" width="150" height="20" title="GitHub"></iframe>"""
|
71 |
if kwargs['verbose']:
|
72 |
description = f"""Model {kwargs['base_model']} Instruct dataset.
|
73 |
For more information, visit our GitHub pages: [h2oGPT](https://github.com/h2oai/h2ogpt) and [H2O LLM Studio](https://github.com/h2oai/h2o-llmstudio).
|
|
|
75 |
Hash: {get_githash()}
|
76 |
"""
|
77 |
else:
|
78 |
+
description = more_info
|
79 |
description += "If this host is busy, try [12B](https://gpt.h2o.ai), [30B](http://gpt2.h2o.ai), [HF Spaces1 12B](https://huggingface.co/spaces/h2oai/h2ogpt-chatbot) or [HF Spaces2 12B](https://huggingface.co/spaces/h2oai/h2ogpt-chatbot2)<br>"
|
80 |
description += """<p>By using h2oGPT, you accept our [Terms of Service](https://github.com/h2oai/h2ogpt/blob/main/tos.md)</p>"""
|
81 |
+
if is_hf:
|
82 |
+
description += '''<a href="https://huggingface.co/spaces/h2oai/h2ogpt-chatbot?duplicate=true"><img src="https://bit.ly/3gLdBN6" style="white-space: nowrap" alt="Duplicate Space"></a>'''
|
83 |
|
84 |
if kwargs['verbose']:
|
85 |
task_info_md = f"""
|
|
|
94 |
"""
|
95 |
else:
|
96 |
css_code = """footer {visibility: hidden}"""
|
97 |
+
css_code += """
|
98 |
+
body.dark{#warning {background-color: #555555};}
|
99 |
+
"""
|
100 |
|
101 |
if kwargs['gradio_avoid_processing_markdown']:
|
102 |
from gradio_client import utils as client_utils
|
|
|
165 |
model_state2 = gr.State([None, None, None, None])
|
166 |
model_options_state = gr.State([model_options])
|
167 |
lora_options_state = gr.State([lora_options])
|
168 |
+
my_db_state = gr.State([None, None])
|
169 |
+
chat_state = gr.State({})
|
170 |
gr.Markdown(f"""
|
171 |
{get_h2o_title(title) if kwargs['h2ocolors'] else get_simple_title(title)}
|
172 |
|
|
|
175 |
""")
|
176 |
if is_hf:
|
177 |
gr.HTML(
|
178 |
+
)
|
179 |
|
180 |
# go button visible if
|
181 |
base_wanted = kwargs['base_model'] != no_model_str and kwargs['login_mode_if_model0']
|
|
|
186 |
with gr.Row():
|
187 |
col_nochat = gr.Column(visible=not kwargs['chat'])
|
188 |
with col_nochat: # FIXME: for model comparison, and check rest
|
189 |
+
text_output_nochat = gr.Textbox(lines=5, label=output_label0).style(show_copy_button=True)
|
190 |
instruction_nochat = gr.Textbox(
|
191 |
lines=kwargs['input_lines'],
|
192 |
label=instruction_label_nochat,
|
|
|
220 |
submit = gr.Button(value='Submit').style(full_width=False, size='sm')
|
221 |
stop_btn = gr.Button(value="Stop").style(full_width=False, size='sm')
|
222 |
with gr.Row():
|
223 |
+
clear = gr.Button("Save, New Conversation")
|
224 |
flag_btn = gr.Button("Flag")
|
225 |
if not kwargs['auto_score']: # FIXME: For checkbox model2
|
226 |
with gr.Column(visible=kwargs['score_model']):
|
|
|
239 |
score_text2 = gr.Textbox("Response Score2: NA", show_label=False, visible=False)
|
240 |
retry = gr.Button("Regenerate")
|
241 |
undo = gr.Button("Undo")
|
242 |
+
with gr.TabItem("Chat"):
|
243 |
with gr.Row():
|
244 |
if 'mbart-' in kwargs['model_lower']:
|
245 |
src_lang = gr.Dropdown(list(languages_covered().keys()),
|
|
|
248 |
tgt_lang = gr.Dropdown(list(languages_covered().keys()),
|
249 |
value=kwargs['tgt_lang'],
|
250 |
label="Output Language")
|
251 |
+
radio_chats = gr.Radio(value=None, label="Saved Chats", visible=True, interactive=True,
|
252 |
+
type='value')
|
253 |
+
with gr.Row():
|
254 |
+
remove_chat_btn = gr.Button(value="Remove Selected Chat", visible=True)
|
255 |
+
clear_chat_btn = gr.Button(value="Clear Chat", visible=True)
|
256 |
+
chats_row = gr.Row(visible=True).style(equal_height=False)
|
257 |
+
with chats_row:
|
258 |
+
export_chats_btn = gr.Button(value="Export Chats")
|
259 |
+
chats_file = gr.File(interactive=False, label="Download File")
|
260 |
+
chats_row2 = gr.Row(visible=True).style(equal_height=False)
|
261 |
+
with chats_row2:
|
262 |
+
chatsup_output = gr.File(label="Upload Chat File(s)",
|
263 |
+
file_types=['.json'],
|
264 |
+
file_count='multiple',
|
265 |
+
elem_id="warning", elem_classes="feedback")
|
266 |
+
add_to_chats_btn = gr.Button("Add File(s) to Chats")
|
267 |
+
with gr.TabItem("Data Source"):
|
268 |
+
langchain_readme = get_url('https://github.com/h2oai/h2ogpt/blob/main/README_LangChain.md',
|
269 |
+
from_str=True)
|
270 |
+
gr.HTML(value=f"""LangChain Support Disabled<p>
|
271 |
+
Run:<p>
|
272 |
+
<code>
|
273 |
+
python generate.py --langchain_mode=MyData
|
274 |
+
</code>
|
275 |
+
<p>
|
276 |
+
For more options see: {langchain_readme}""",
|
277 |
+
visible=kwargs['langchain_mode'] == 'Disabled', interactive=False)
|
278 |
+
data_row = gr.Row(visible=kwargs['langchain_mode'] != 'Disabled')
|
279 |
+
with data_row:
|
280 |
+
if is_hf:
|
281 |
+
# don't show 'wiki' since only usually useful for internal testing at moment
|
282 |
+
no_show_modes = ['Disabled', 'wiki']
|
283 |
+
else:
|
284 |
+
no_show_modes = ['Disabled']
|
285 |
+
allowed_modes = visible_langchain_modes.copy()
|
286 |
+
allowed_modes = [x for x in allowed_modes if x in dbs]
|
287 |
+
allowed_modes += ['ChatLLM', 'LLM']
|
288 |
+
if allow_upload_to_my_data and 'MyData' not in allowed_modes:
|
289 |
+
allowed_modes += ['MyData']
|
290 |
+
if allow_upload_to_user_data and 'UserData' not in allowed_modes:
|
291 |
+
allowed_modes += ['UserData']
|
292 |
+
langchain_mode = gr.Radio(
|
293 |
+
[x for x in langchain_modes if x in allowed_modes and x not in no_show_modes],
|
294 |
+
value=kwargs['langchain_mode'],
|
295 |
+
label="Data Source",
|
296 |
+
visible=kwargs['langchain_mode'] != 'Disabled')
|
297 |
+
|
298 |
+
def upload_file(files, x):
|
299 |
+
file_paths = [file.name for file in files]
|
300 |
+
return files, file_paths
|
301 |
+
|
302 |
+
upload_row = gr.Row(visible=kwargs['langchain_mode'] != 'Disabled' and allow_upload).style(
|
303 |
+
equal_height=False)
|
304 |
+
# import control
|
305 |
+
if kwargs['langchain_mode'] != 'Disabled':
|
306 |
+
from gpt_langchain import file_types, have_arxiv
|
307 |
+
else:
|
308 |
+
have_arxiv = False
|
309 |
+
file_types = []
|
310 |
+
with upload_row:
|
311 |
+
file_types_str = '[' + ' '.join(file_types) + ']'
|
312 |
+
fileup_output = gr.File(label=f'Upload {file_types_str}',
|
313 |
+
file_types=file_types,
|
314 |
+
file_count="multiple",
|
315 |
+
elem_id="warning", elem_classes="feedback")
|
316 |
+
with gr.Row():
|
317 |
+
upload_button = gr.UploadButton("Upload %s" % file_types_str,
|
318 |
+
file_types=file_types,
|
319 |
+
file_count="multiple",
|
320 |
+
visible=False,
|
321 |
+
)
|
322 |
+
# add not visible until upload something
|
323 |
+
with gr.Column():
|
324 |
+
add_to_shared_db_btn = gr.Button("Add File(s) to Shared UserData DB",
|
325 |
+
visible=allow_upload_to_user_data) # and False)
|
326 |
+
add_to_my_db_btn = gr.Button("Add File(s) to Scratch MyData DB",
|
327 |
+
visible=allow_upload_to_my_data) # and False)
|
328 |
+
url_row = gr.Row(
|
329 |
+
visible=kwargs['langchain_mode'] != 'Disabled' and allow_upload and enable_url_upload).style(
|
330 |
+
equal_height=False)
|
331 |
+
with url_row:
|
332 |
+
url_label = 'URL (http/https) or ArXiv:' if have_arxiv else 'URL (http/https)'
|
333 |
+
url_text = gr.Textbox(label=url_label, interactive=True)
|
334 |
+
with gr.Column():
|
335 |
+
url_user_btn = gr.Button(value='Add URL content to Shared UserData DB',
|
336 |
+
visible=allow_upload_to_user_data)
|
337 |
+
url_my_btn = gr.Button(value='Add URL content to Scratch MyData DB',
|
338 |
+
visible=allow_upload_to_my_data)
|
339 |
+
text_row = gr.Row(
|
340 |
+
visible=kwargs['langchain_mode'] != 'Disabled' and allow_upload and enable_text_upload).style(
|
341 |
+
equal_height=False)
|
342 |
+
with text_row:
|
343 |
+
user_text_text = gr.Textbox(label='Paste Text', interactive=True)
|
344 |
+
with gr.Column():
|
345 |
+
user_text_user_btn = gr.Button(value='Add Text to Shared UserData DB',
|
346 |
+
visible=allow_upload_to_user_data)
|
347 |
+
user_text_my_btn = gr.Button(value='Add Text to Scratch MyData DB',
|
348 |
+
visible=allow_upload_to_my_data)
|
349 |
+
# WIP:
|
350 |
+
with gr.Row(visible=False).style(equal_height=False):
|
351 |
+
github_textbox = gr.Textbox(label="Github URL")
|
352 |
+
with gr.Row(visible=True):
|
353 |
+
github_shared_btn = gr.Button(value="Add Github to Shared UserData DB",
|
354 |
+
visible=allow_upload_to_user_data)
|
355 |
+
github_my_btn = gr.Button(value="Add Github to Scratch MyData DB",
|
356 |
+
visible=allow_upload_to_my_data)
|
357 |
+
sources_row = gr.Row(visible=kwargs['langchain_mode'] != 'Disabled' and enable_sources_list).style(
|
358 |
+
equal_height=False)
|
359 |
+
with sources_row:
|
360 |
+
sources_text = gr.HTML(label='Sources Added', interactive=False)
|
361 |
+
sources_row2 = gr.Row(visible=kwargs['langchain_mode'] != 'Disabled' and enable_sources_list).style(
|
362 |
+
equal_height=False)
|
363 |
+
with sources_row2:
|
364 |
+
get_sources_btn = gr.Button(value="Get Sources List for Selected DB")
|
365 |
+
file_source = gr.File(interactive=False, label="Download File with list of Sources")
|
366 |
+
|
367 |
with gr.TabItem("Expert"):
|
368 |
with gr.Row():
|
369 |
with gr.Column():
|
|
|
392 |
)
|
393 |
# FIXME: https://github.com/h2oai/h2ogpt/issues/106
|
394 |
if os.getenv('TESTINGFAIL'):
|
395 |
+
max_beams = 8 if not (is_low_mem or is_public) else 1
|
396 |
else:
|
397 |
max_beams = 1
|
398 |
num_beams = gr.Slider(minimum=1, maximum=max_beams, step=1,
|
|
|
505 |
with gr.Column():
|
506 |
with gr.Row():
|
507 |
system_btn = gr.Button(value='Get System Info')
|
508 |
+
system_text = gr.Textbox(label='System Info', interactive=False).style(
|
509 |
+
show_copy_button=True)
|
510 |
|
511 |
with gr.Row():
|
512 |
zip_btn = gr.Button("Zip")
|
513 |
zip_text = gr.Textbox(label="Zip file name", interactive=False)
|
514 |
+
file_output = gr.File(interactive=False)
|
515 |
with gr.Row():
|
516 |
s3up_btn = gr.Button("S3UP")
|
517 |
s3up_text = gr.Textbox(label='S3UP result', interactive=False)
|
|
|
528 |
|
529 |
# Get flagged data
|
530 |
zip_data1 = functools.partial(zip_data, root_dirs=['flagged_data_points', kwargs['save_dir']])
|
531 |
+
zip_btn.click(zip_data1, inputs=None, outputs=[file_output, zip_text], queue=False,
|
532 |
+
api_name='zip_data' if allow_api else None)
|
533 |
+
s3up_btn.click(s3up, inputs=zip_text, outputs=s3up_text, queue=False,
|
534 |
+
api_name='s3up_data' if allow_api else None)
|
535 |
+
|
536 |
+
def make_add_visible(x):
|
537 |
+
return gr.update(visible=x is not None)
|
538 |
+
|
539 |
+
def clear_file_list():
|
540 |
+
return None
|
541 |
+
|
542 |
+
def make_invisible():
|
543 |
+
return gr.update(visible=False)
|
544 |
+
|
545 |
+
def make_visible():
|
546 |
+
return gr.update(visible=True)
|
547 |
+
|
548 |
+
# add itself to output to ensure shows working and can't click again
|
549 |
+
upload_button.upload(upload_file, inputs=[upload_button, fileup_output],
|
550 |
+
outputs=[upload_button, fileup_output], queue=queue,
|
551 |
+
api_name='upload_file' if allow_api else None) \
|
552 |
+
.then(make_add_visible, fileup_output, add_to_shared_db_btn, queue=queue) \
|
553 |
+
.then(make_add_visible, fileup_output, add_to_my_db_btn, queue=queue) \
|
554 |
+
.then(make_invisible, outputs=upload_button, queue=queue)
|
555 |
+
|
556 |
+
# Add to UserData
|
557 |
+
update_user_db_func = functools.partial(update_user_db, dbs=dbs, db_type=db_type, langchain_mode='UserData',
|
558 |
+
use_openai_embedding=use_openai_embedding,
|
559 |
+
hf_embedding_model=hf_embedding_model,
|
560 |
+
enable_captions=enable_captions,
|
561 |
+
captions_model=captions_model,
|
562 |
+
enable_ocr=enable_ocr,
|
563 |
+
caption_loader=caption_loader,
|
564 |
+
)
|
565 |
+
|
566 |
+
# note for update_user_db_func output is ignored for db
|
567 |
+
add_to_shared_db_btn.click(update_user_db_func,
|
568 |
+
inputs=[fileup_output, my_db_state, add_to_shared_db_btn, add_to_my_db_btn],
|
569 |
+
outputs=[add_to_shared_db_btn, add_to_my_db_btn, sources_text], queue=queue,
|
570 |
+
api_name='add_to_shared' if allow_api else None) \
|
571 |
+
.then(clear_file_list, outputs=fileup_output, queue=queue)
|
572 |
+
|
573 |
+
# .then(make_invisible, outputs=add_to_shared_db_btn, queue=queue)
|
574 |
+
# .then(make_visible, outputs=upload_button, queue=queue)
|
575 |
+
|
576 |
+
def clear_textbox():
|
577 |
+
return gr.Textbox.update(value='')
|
578 |
+
|
579 |
+
update_user_db_url_func = functools.partial(update_user_db_func, is_url=True)
|
580 |
+
url_user_btn.click(update_user_db_url_func,
|
581 |
+
inputs=[url_text, my_db_state, add_to_shared_db_btn, add_to_my_db_btn],
|
582 |
+
outputs=[add_to_shared_db_btn, add_to_my_db_btn, sources_text], queue=queue,
|
583 |
+
api_name='add_url_to_shared' if allow_api else None) \
|
584 |
+
.then(clear_textbox, outputs=url_text, queue=queue)
|
585 |
+
|
586 |
+
update_user_db_txt_func = functools.partial(update_user_db_func, is_txt=True)
|
587 |
+
user_text_user_btn.click(update_user_db_txt_func,
|
588 |
+
inputs=[user_text_text, my_db_state, add_to_shared_db_btn, add_to_my_db_btn],
|
589 |
+
outputs=[add_to_shared_db_btn, add_to_my_db_btn, sources_text], queue=queue,
|
590 |
+
api_name='add_text_to_shared' if allow_api else None) \
|
591 |
+
.then(clear_textbox, outputs=user_text_text, queue=queue)
|
592 |
+
|
593 |
+
# Add to MyData
|
594 |
+
update_my_db_func = functools.partial(update_user_db, dbs=dbs, db_type=db_type, langchain_mode='MyData',
|
595 |
+
use_openai_embedding=use_openai_embedding,
|
596 |
+
hf_embedding_model=hf_embedding_model,
|
597 |
+
enable_captions=enable_captions,
|
598 |
+
captions_model=captions_model,
|
599 |
+
enable_ocr=enable_ocr,
|
600 |
+
caption_loader=caption_loader,
|
601 |
+
)
|
602 |
+
|
603 |
+
add_to_my_db_btn.click(update_my_db_func,
|
604 |
+
inputs=[fileup_output, my_db_state, add_to_shared_db_btn, add_to_my_db_btn],
|
605 |
+
outputs=[my_db_state, add_to_shared_db_btn, add_to_my_db_btn, sources_text], queue=queue,
|
606 |
+
api_name='add_to_my' if allow_api else None) \
|
607 |
+
.then(clear_file_list, outputs=fileup_output, queue=queue)
|
608 |
+
# .then(make_invisible, outputs=add_to_shared_db_btn, queue=queue)
|
609 |
+
# .then(make_visible, outputs=upload_button, queue=queue)
|
610 |
+
|
611 |
+
update_my_db_url_func = functools.partial(update_my_db_func, is_url=True)
|
612 |
+
url_my_btn.click(update_my_db_url_func,
|
613 |
+
inputs=[url_text, my_db_state, add_to_shared_db_btn, add_to_my_db_btn],
|
614 |
+
outputs=[my_db_state, add_to_shared_db_btn, add_to_my_db_btn, sources_text], queue=queue,
|
615 |
+
api_name='add_url_to_my' if allow_api else None) \
|
616 |
+
.then(clear_textbox, outputs=url_text, queue=queue)
|
617 |
+
|
618 |
+
update_my_db_txt_func = functools.partial(update_my_db_func, is_txt=True)
|
619 |
+
user_text_my_btn.click(update_my_db_txt_func,
|
620 |
+
inputs=[user_text_text, my_db_state, add_to_shared_db_btn, add_to_my_db_btn],
|
621 |
+
outputs=[my_db_state, add_to_shared_db_btn, add_to_my_db_btn, sources_text], queue=queue,
|
622 |
+
api_name='add_txt_to_my' if allow_api else None) \
|
623 |
+
.then(clear_textbox, outputs=user_text_text, queue=queue)
|
624 |
+
|
625 |
+
get_sources1 = functools.partial(get_sources, dbs=dbs)
|
626 |
+
get_sources_btn.click(get_sources1, inputs=[my_db_state, langchain_mode], outputs=file_source, queue=queue,
|
627 |
+
api_name='get_sources' if allow_api else None)
|
628 |
|
629 |
def check_admin_pass(x):
|
630 |
return gr.update(visible=x == admin_pass)
|
|
|
814 |
"""
|
815 |
# don't deepcopy, can contain model itself
|
816 |
args_list = list(args).copy()
|
817 |
+
model_state1 = args_list[-3]
|
818 |
+
my_db_state1 = args_list[-2]
|
819 |
+
history = args_list[-1]
|
820 |
+
|
821 |
+
args_list = args_list[:-3] # only keep rest needed for evaluate()
|
822 |
+
langchain_mode1 = args_list[eval_func_param_names.index('langchain_mode')]
|
823 |
if retry and history:
|
824 |
history.pop()
|
825 |
+
if not args_list[eval_func_param_names.index('do_sample')]:
|
826 |
+
# if was not sampling, no point in retry unless change to sample
|
827 |
+
args_list[eval_func_param_names.index('do_sample')] = True
|
828 |
if not history:
|
829 |
print("No history", flush=True)
|
830 |
history = [['', None]]
|
831 |
yield history, ''
|
832 |
return
|
833 |
# ensure output will be unique to models
|
834 |
+
_, _, _, max_prompt_length = get_cutoffs(is_low_mem, for_context=True)
|
835 |
history = copy.deepcopy(history)
|
836 |
instruction1 = history[-1][0]
|
837 |
context1 = ''
|
838 |
+
if max_prompt_length is not None and langchain_mode1 not in ['LLM']:
|
839 |
+
prompt_type1 = args_list[eval_func_param_names.index('prompt_type')]
|
840 |
+
chat1 = args_list[eval_func_param_names.index('chat')]
|
|
|
|
|
841 |
context1 = ''
|
842 |
+
# - 1 below because current instruction already in history from user()
|
843 |
+
for histi in range(0, len(history) - 1):
|
844 |
data_point = dict(instruction=history[histi][0], input='', output=history[histi][1])
|
845 |
prompt, pre_response, terminate_response, chat_sep = generate_prompt(data_point, prompt_type1,
|
846 |
chat1, reduced=True)
|
847 |
+
# md -> back to text, maybe not super important if model trained enough
|
848 |
+
if not kwargs['keep_sources_in_context']:
|
849 |
+
from gpt_langchain import source_prefix, source_postfix
|
850 |
+
import re
|
851 |
+
prompt = re.sub(f'{re.escape(source_prefix)}.*?{re.escape(source_postfix)}', '', prompt,
|
852 |
+
flags=re.DOTALL)
|
853 |
+
if prompt.endswith('\n<p>'):
|
854 |
+
prompt = prompt[:-4]
|
855 |
prompt = prompt.replace('<br>', chat_sep)
|
856 |
+
if not prompt.endswith(chat_sep):
|
857 |
+
prompt += chat_sep
|
858 |
+
# most recent first, add older if can
|
859 |
+
# only include desired chat history
|
860 |
+
if len(prompt + context1) > max_prompt_length:
|
861 |
+
break
|
862 |
+
context1 = prompt + context1
|
863 |
|
864 |
_, pre_response, terminate_response, chat_sep = generate_prompt({}, prompt_type1, chat1,
|
865 |
reduced=True)
|
866 |
if context1 and not context1.endswith(chat_sep):
|
867 |
context1 += chat_sep # ensure if terminates abruptly, then human continues on next line
|
868 |
args_list[0] = instruction1 # override original instruction with history from user
|
869 |
+
args_list[2] = context1
|
|
|
|
|
870 |
if model_state1[0] is None or model_state1[0] == no_model_str:
|
871 |
history = [['', None]]
|
872 |
yield history, ''
|
873 |
return
|
|
|
874 |
fun1 = partial(evaluate,
|
875 |
model_state1,
|
876 |
+
my_db_state1,
|
877 |
**kwargs_evaluate)
|
878 |
try:
|
879 |
for output in fun1(*tuple(args_list)):
|
|
|
907 |
outputs=text_output,
|
908 |
)
|
909 |
bot_args = dict(fn=bot,
|
910 |
+
inputs=inputs_list + [model_state, my_db_state] + [text_output],
|
911 |
outputs=[text_output, exception_text],
|
912 |
)
|
913 |
retry_bot_args = dict(fn=functools.partial(bot, retry=True),
|
914 |
+
inputs=inputs_list + [model_state, my_db_state] + [text_output],
|
915 |
outputs=[text_output, exception_text],
|
916 |
)
|
917 |
undo_user_args = dict(fn=functools.partial(user, undo=True),
|
|
|
925 |
outputs=text_output2,
|
926 |
)
|
927 |
bot_args2 = dict(fn=bot,
|
928 |
+
inputs=inputs_list + [model_state2, my_db_state] + [text_output2],
|
929 |
outputs=[text_output2, exception_text],
|
930 |
)
|
931 |
retry_bot_args2 = dict(fn=functools.partial(bot, retry=True),
|
932 |
+
inputs=inputs_list + [model_state2, my_db_state] + [text_output2],
|
933 |
outputs=[text_output2, exception_text],
|
934 |
)
|
935 |
undo_user_args2 = dict(fn=functools.partial(user, undo=True),
|
|
|
956 |
.then(clear_instruct, None, iinput)
|
957 |
submit_event1d = submit_event1c.then(**bot_args, api_name='instruction_bot' if allow_api else None,
|
958 |
queue=queue)
|
959 |
+
submit_event1e = submit_event1d.then(**score_args_submit,
|
960 |
+
api_name='instruction_bot_score' if allow_api else None,
|
961 |
queue=queue)
|
962 |
submit_event1f = submit_event1e.then(**bot_args2, api_name='instruction_bot2' if allow_api else None,
|
963 |
queue=queue)
|
|
|
998 |
.then(**score_args_submit, api_name='undo_score' if allow_api else None) \
|
999 |
.then(**score_args2_submit, api_name='undo_score2' if allow_api else None)
|
1000 |
|
1001 |
+
# MANAGE CHATS
|
1002 |
+
def dedup(short_chat, short_chats):
|
1003 |
+
if short_chat not in short_chats:
|
1004 |
+
return short_chat
|
1005 |
+
for i in range(1, 1000):
|
1006 |
+
short_chat_try = short_chat + "_" + str(i)
|
1007 |
+
if short_chat_try not in short_chats:
|
1008 |
+
return short_chat_try
|
1009 |
+
# fallback and hope for best
|
1010 |
+
short_chat = short_chat + "_" + str(random.random())
|
1011 |
+
return short_chat
|
1012 |
+
|
1013 |
+
def get_short_chat(x, short_chats, short_len=20, words=4):
|
1014 |
+
if x and len(x[0]) == 2 and x[0][0] is not None:
|
1015 |
+
short_chat = ' '.join(x[0][0][:short_len].split(' ')[:words]).strip()
|
1016 |
+
short_chat = dedup(short_chat, short_chats)
|
1017 |
+
else:
|
1018 |
+
short_chat = None
|
1019 |
+
return short_chat
|
1020 |
+
|
1021 |
+
def is_chat_same(x, y):
|
1022 |
+
# <p> etc. added in chat, try to remove some of that to help avoid dup entries when hit new conversation
|
1023 |
+
is_same = True
|
1024 |
+
# length of conversation has to be same
|
1025 |
+
if len(x) != len(y):
|
1026 |
+
return False
|
1027 |
+
for stepx, stepy in zip(x, y):
|
1028 |
+
if len(stepx) != len(stepy):
|
1029 |
+
# something off with a conversation
|
1030 |
+
return False
|
1031 |
+
if len(stepx) != 2:
|
1032 |
+
# something off
|
1033 |
+
return False
|
1034 |
+
if len(stepy) != 2:
|
1035 |
+
# something off
|
1036 |
+
return False
|
1037 |
+
questionx = stepx[0].replace('<p>', '').replace('</p>', '')
|
1038 |
+
answerx = stepx[1].replace('<p>', '').replace('</p>', '')
|
1039 |
+
|
1040 |
+
questiony = stepy[0].replace('<p>', '').replace('</p>', '')
|
1041 |
+
answery = stepy[1].replace('<p>', '').replace('</p>', '')
|
1042 |
+
|
1043 |
+
if questionx != questiony or answerx != answery:
|
1044 |
+
return False
|
1045 |
+
return is_same
|
1046 |
+
|
1047 |
+
def save_chat(chat1, chat2, chat_state1):
|
1048 |
+
short_chats = list(chat_state1.keys())
|
1049 |
+
for chati in [chat1, chat2]:
|
1050 |
+
if chati and len(chati) > 0 and len(chati[0]) == 2 and chati[0][1] is not None:
|
1051 |
+
short_chat = get_short_chat(chati, short_chats)
|
1052 |
+
if short_chat:
|
1053 |
+
already_exists = any([is_chat_same(chati, x) for x in chat_state1.values()])
|
1054 |
+
if not already_exists:
|
1055 |
+
chat_state1[short_chat] = chati
|
1056 |
+
return chat_state1
|
1057 |
+
|
1058 |
+
def update_radio_chats(chat_state1):
|
1059 |
+
return gr.update(choices=list(chat_state1.keys()), value=None)
|
1060 |
+
|
1061 |
+
def deselect_radio_chats():
|
1062 |
+
return gr.update(value=None)
|
1063 |
+
|
1064 |
+
def switch_chat(chat_key, chat_state1):
|
1065 |
+
chosen_chat = chat_state1[chat_key]
|
1066 |
+
return chosen_chat, chosen_chat
|
1067 |
+
|
1068 |
+
radio_chats.input(switch_chat, inputs=[radio_chats, chat_state], outputs=[text_output, text_output2])
|
1069 |
+
|
1070 |
+
def remove_chat(chat_key, chat_state1):
|
1071 |
+
chat_state1.pop(chat_key, None)
|
1072 |
+
return chat_state1
|
1073 |
+
|
1074 |
+
remove_chat_btn.click(remove_chat, inputs=[radio_chats, chat_state], outputs=chat_state) \
|
1075 |
+
.then(update_radio_chats, inputs=chat_state, outputs=radio_chats)
|
1076 |
+
|
1077 |
+
def get_chats1(chat_state1):
|
1078 |
+
base = 'chats'
|
1079 |
+
makedirs(base, exist_ok=True)
|
1080 |
+
filename = os.path.join(base, 'chats_%s.json' % str(uuid.uuid4()))
|
1081 |
+
with open(filename, "wt") as f:
|
1082 |
+
f.write(json.dumps(chat_state1, indent=2))
|
1083 |
+
return filename
|
1084 |
+
|
1085 |
+
export_chats_btn.click(get_chats1, inputs=chat_state, outputs=chats_file, queue=False,
|
1086 |
+
api_name='export_chats' if allow_api else None)
|
1087 |
+
|
1088 |
+
def add_chats_from_file(file, chat_state1, add_btn):
|
1089 |
+
if isinstance(file, str):
|
1090 |
+
files = [file]
|
1091 |
+
else:
|
1092 |
+
files = file
|
1093 |
+
for file1 in files:
|
1094 |
+
try:
|
1095 |
+
if hasattr(file1, 'name'):
|
1096 |
+
file1 = file1.name
|
1097 |
+
with open(file1, "rt") as f:
|
1098 |
+
new_chats = json.loads(f.read())
|
1099 |
+
for chat1_k, chat1_v in new_chats.items():
|
1100 |
+
# ignore chat1_k, regenerate and de-dup to avoid loss
|
1101 |
+
chat_state1 = save_chat(chat1_v, None, chat_state1)
|
1102 |
+
except BaseException as e:
|
1103 |
+
print("Add chats exception: %s" % str(e), flush=True)
|
1104 |
+
return chat_state1, add_btn
|
1105 |
+
|
1106 |
+
# note for update_user_db_func output is ignored for db
|
1107 |
+
add_to_chats_btn.click(add_chats_from_file,
|
1108 |
+
inputs=[chatsup_output, chat_state, add_to_chats_btn],
|
1109 |
+
outputs=[chat_state, add_to_my_db_btn], queue=False,
|
1110 |
+
api_name='add_to_chats' if allow_api else None) \
|
1111 |
+
.then(clear_file_list, outputs=chatsup_output, queue=False) \
|
1112 |
+
.then(update_radio_chats, inputs=chat_state, outputs=radio_chats, queue=False)
|
1113 |
+
|
1114 |
+
clear_chat_btn.click(lambda: None, None, text_output, queue=False, api_name='clear' if allow_api else None) \
|
1115 |
+
.then(lambda: None, None, text_output2, queue=False, api_name='clear2' if allow_api else None) \
|
1116 |
+
.then(deselect_radio_chats, inputs=None, outputs=radio_chats, queue=False)
|
1117 |
+
|
1118 |
# does both models
|
1119 |
+
clear.click(save_chat, inputs=[text_output, text_output2, chat_state], outputs=chat_state,
|
1120 |
+
api_name='save_chat' if allow_api else None) \
|
1121 |
+
.then(update_radio_chats, inputs=chat_state, outputs=radio_chats,
|
1122 |
+
api_name='update_chats' if allow_api else None) \
|
1123 |
+
.then(lambda: None, None, text_output, queue=False, api_name='clearB' if allow_api else None) \
|
1124 |
+
.then(lambda: None, None, text_output2, queue=False, api_name='clearB2' if allow_api else None)
|
1125 |
# NOTE: clear of instruction/iinput for nochat has to come after score,
|
1126 |
# because score for nochat consumes actual textbox, while chat consumes chat history filled by user()
|
1127 |
+
submit_event_nochat = submit_nochat.click(fun,
|
1128 |
+
inputs=[model_state, my_db_state] + inputs_list,
|
1129 |
outputs=text_output_nochat,
|
1130 |
queue=queue,
|
1131 |
api_name='submit_nochat' if allow_api else None) \
|
|
|
1227 |
new_state = [list0[0] + [x]]
|
1228 |
new_options = [*new_state[0]]
|
1229 |
return gr.Dropdown.update(value=x, choices=new_options), \
|
1230 |
+
gr.Dropdown.update(value=x, choices=new_options), \
|
1231 |
+
'', new_state
|
1232 |
|
1233 |
add_model_event = add_model_button.click(fn=dropdown_model_list,
|
1234 |
inputs=[model_options_state, new_model],
|
|
|
1242 |
x1 = x if model_used1 == no_model_str else lora_used1
|
1243 |
x2 = x if model_used2 == no_model_str else lora_used2
|
1244 |
return gr.Dropdown.update(value=x1, choices=new_options), \
|
1245 |
+
gr.Dropdown.update(value=x2, choices=new_options), \
|
1246 |
+
'', new_state
|
1247 |
|
1248 |
add_lora_event = add_lora_button.click(fn=dropdown_lora_list,
|
1249 |
inputs=[lora_options_state, new_lora, model_used, lora_used, model_used2,
|
|
|
1301 |
|
1302 |
scheduler = BackgroundScheduler()
|
1303 |
scheduler.add_job(func=clear_torch_cache, trigger="interval", seconds=20)
|
1304 |
+
if is_public and \
|
1305 |
+
kwargs['base_model'] not in ['gptj', 'llama']:
|
1306 |
+
# FIXME: disable for gptj, langchain or gpt4all modify print itself
|
1307 |
+
# FIXME: and any multi-threaded/async print will enter model output!
|
1308 |
scheduler.add_job(func=ping, trigger="interval", seconds=60)
|
1309 |
scheduler.start()
|
1310 |
|
1311 |
+
# import control
|
1312 |
+
if kwargs['langchain_mode'] == 'Disabled' and \
|
1313 |
+
os.environ.get("TEST_LANGCHAIN_IMPORT") and \
|
1314 |
+
kwargs['base_model'] not in ['gptj', 'llama']:
|
1315 |
+
assert 'gpt_langchain' not in sys.modules, "Dev bug, import of langchain when should not have"
|
1316 |
+
assert 'langchain' not in sys.modules, "Dev bug, import of langchain when should not have"
|
1317 |
+
|
1318 |
demo.launch(share=kwargs['share'], server_name="0.0.0.0", show_error=True,
|
1319 |
favicon_path=favicon_path, prevent_thread_lock=True,
|
1320 |
auth=kwargs['auth'])
|
|
|
1323 |
demo.block_thread()
|
1324 |
|
1325 |
|
1326 |
+
input_args_list = ['model_state', 'my_db_state']
|
|
|
|
|
1327 |
|
1328 |
|
1329 |
def get_inputs_list(inputs_dict, model_lower):
|
|
|
1339 |
if k == 'kwargs':
|
1340 |
continue
|
1341 |
if k in input_args_list + inputs_kwargs_list:
|
1342 |
+
# these are added at use time for args or partial for kwargs, not taken as input
|
1343 |
continue
|
1344 |
if 'mbart-' not in model_lower and k in ['src_lang', 'tgt_lang']:
|
1345 |
continue
|
1346 |
inputs_list.append(inputs_dict[k])
|
1347 |
return inputs_list
|
1348 |
+
|
1349 |
+
|
1350 |
+
def get_sources(db1, langchain_mode, dbs=None):
|
1351 |
+
if langchain_mode in ['ChatLLM', 'LLM']:
|
1352 |
+
source_files_added = "NA"
|
1353 |
+
elif langchain_mode in ['wiki_full']:
|
1354 |
+
source_files_added = "Not showing wiki_full, takes about 20 seconds and makes 4MB file." \
|
1355 |
+
" Ask jon.mckinney@h2o.ai for file if required."
|
1356 |
+
elif langchain_mode == 'MyData' and len(db1) > 0 and db1[0] is not None:
|
1357 |
+
db_get = db1[0].get()
|
1358 |
+
source_files_added = '\n'.join(sorted(set([x['source'] for x in db_get['metadatas']])))
|
1359 |
+
elif langchain_mode in dbs and dbs[langchain_mode] is not None:
|
1360 |
+
db1 = dbs[langchain_mode]
|
1361 |
+
db_get = db1.get()
|
1362 |
+
source_files_added = '\n'.join(sorted(set([x['source'] for x in db_get['metadatas']])))
|
1363 |
+
else:
|
1364 |
+
source_files_added = "None"
|
1365 |
+
sources_file = 'sources_%s_%s' % (langchain_mode, str(uuid.uuid4()))
|
1366 |
+
with open(sources_file, "wt") as f:
|
1367 |
+
f.write(source_files_added)
|
1368 |
+
return sources_file
|
1369 |
+
|
1370 |
+
|
1371 |
+
def update_user_db(file, db1, x, y, *args, dbs=None, langchain_mode='UserData', **kwargs):
|
1372 |
+
try:
|
1373 |
+
return _update_user_db(file, db1, x, y, *args, dbs=dbs, langchain_mode=langchain_mode, **kwargs)
|
1374 |
+
except BaseException as e:
|
1375 |
+
print(traceback.format_exc(), flush=True)
|
1376 |
+
# gradio has issues if except, so fail semi-gracefully, else would hang forever in processing textbox
|
1377 |
+
ex_str = "Exception: %s" % str(e)
|
1378 |
+
source_files_added = """\
|
1379 |
+
<html>
|
1380 |
+
<body>
|
1381 |
+
<p>
|
1382 |
+
Sources: <br>
|
1383 |
+
</p>
|
1384 |
+
<div style="overflow-y: auto;height:400px">
|
1385 |
+
{0}
|
1386 |
+
</div>
|
1387 |
+
</body>
|
1388 |
+
</html>
|
1389 |
+
""".format(ex_str)
|
1390 |
+
if langchain_mode == 'MyData':
|
1391 |
+
return db1, x, y, source_files_added
|
1392 |
+
else:
|
1393 |
+
return x, y, source_files_added
|
1394 |
+
|
1395 |
+
|
1396 |
+
def _update_user_db(file, db1, x, y, dbs=None, db_type=None, langchain_mode='UserData', use_openai_embedding=False,
|
1397 |
+
hf_embedding_model="sentence-transformers/all-MiniLM-L6-v2",
|
1398 |
+
caption_loader=None,
|
1399 |
+
enable_captions=True,
|
1400 |
+
captions_model="Salesforce/blip-image-captioning-base",
|
1401 |
+
enable_ocr=False,
|
1402 |
+
verbose=False,
|
1403 |
+
chunk=True, chunk_size=512, is_url=False, is_txt=False):
|
1404 |
+
assert isinstance(dbs, dict), "Wrong type for dbs: %s" % str(type(dbs))
|
1405 |
+
assert db_type in ['faiss', 'chroma'], "db_type %s not supported" % db_type
|
1406 |
+
from gpt_langchain import add_to_db, get_db, path_to_docs
|
1407 |
+
# handle case of list of temp buffer
|
1408 |
+
if isinstance(file, list) and len(file) > 0 and hasattr(file[0], 'name'):
|
1409 |
+
file = [x.name for x in file]
|
1410 |
+
# handle single file of temp buffer
|
1411 |
+
if hasattr(file, 'name'):
|
1412 |
+
file = file.name
|
1413 |
+
if verbose:
|
1414 |
+
print("Adding %s" % file, flush=True)
|
1415 |
+
sources = path_to_docs(file if not is_url and not is_txt else None,
|
1416 |
+
verbose=verbose, chunk=chunk, chunk_size=chunk_size,
|
1417 |
+
url=file if is_url else None,
|
1418 |
+
text=file if is_txt else None,
|
1419 |
+
enable_captions=enable_captions,
|
1420 |
+
captions_model=captions_model,
|
1421 |
+
enable_ocr=enable_ocr,
|
1422 |
+
caption_loader=caption_loader,
|
1423 |
+
)
|
1424 |
+
exceptions = [x for x in sources if x.metadata.get('exception')]
|
1425 |
+
sources = [x for x in sources if 'exception' not in x.metadata]
|
1426 |
+
|
1427 |
+
with filelock.FileLock("db_%s.lock" % langchain_mode.replace(' ', '_')):
|
1428 |
+
if langchain_mode == 'MyData':
|
1429 |
+
if db1[0] is not None:
|
1430 |
+
# then add
|
1431 |
+
add_to_db(db1[0], sources, db_type=db_type)
|
1432 |
+
else:
|
1433 |
+
assert len(db1) == 2 and db1[1] is None, "Bad MyData db: %s" % db1
|
1434 |
+
# then create
|
1435 |
+
# assign fresh hash for this user session, so not shared
|
1436 |
+
# if added has to original state and didn't change, then would be shared db for all users
|
1437 |
+
db1[1] = str(uuid.uuid4())
|
1438 |
+
persist_directory = os.path.join(scratch_base_dir, 'db_dir_%s_%s' % (langchain_mode, db1[1]))
|
1439 |
+
db1[0] = get_db(sources, use_openai_embedding=use_openai_embedding,
|
1440 |
+
db_type=db_type,
|
1441 |
+
persist_directory=persist_directory,
|
1442 |
+
langchain_mode=langchain_mode,
|
1443 |
+
hf_embedding_model=hf_embedding_model)
|
1444 |
+
if db1[0] is None:
|
1445 |
+
db1[1] = None
|
1446 |
+
source_files_added = get_source_files(db1[0], exceptions=exceptions)
|
1447 |
+
return db1, x, y, source_files_added
|
1448 |
+
else:
|
1449 |
+
persist_directory = 'db_dir_%s' % langchain_mode
|
1450 |
+
if langchain_mode in dbs and dbs[langchain_mode] is not None:
|
1451 |
+
# then add
|
1452 |
+
add_to_db(dbs[langchain_mode], sources, db_type=db_type)
|
1453 |
+
else:
|
1454 |
+
# then create
|
1455 |
+
db = get_db(sources, use_openai_embedding=use_openai_embedding,
|
1456 |
+
db_type=db_type,
|
1457 |
+
persist_directory=persist_directory,
|
1458 |
+
langchain_mode=langchain_mode,
|
1459 |
+
hf_embedding_model=hf_embedding_model)
|
1460 |
+
dbs[langchain_mode] = db
|
1461 |
+
# NOTE we do not return db, because function call always same code path
|
1462 |
+
# return dbs[langchain_mode], x, y
|
1463 |
+
# db in this code path is updated in place
|
1464 |
+
source_files_added = get_source_files(dbs[langchain_mode], exceptions=exceptions)
|
1465 |
+
return x, y, source_files_added
|
1466 |
+
|
1467 |
+
|
1468 |
+
def get_source_files(db, exceptions=None):
|
1469 |
+
if exceptions is None:
|
1470 |
+
exceptions = []
|
1471 |
+
|
1472 |
+
if db is not None:
|
1473 |
+
metadatas = db.get()['metadatas']
|
1474 |
+
else:
|
1475 |
+
metadatas = []
|
1476 |
+
|
1477 |
+
# below automatically de-dups
|
1478 |
+
from gpt_langchain import get_url
|
1479 |
+
small_dict = {get_url(x['source'], from_str=True, short_name=True): get_short_name(x.get('head')) for x in
|
1480 |
+
metadatas}
|
1481 |
+
# if small_dict is empty dict, that's ok
|
1482 |
+
df = pd.DataFrame(small_dict.items(), columns=['source', 'head'])
|
1483 |
+
df.index = df.index + 1
|
1484 |
+
df.index.name = 'index'
|
1485 |
+
source_files_added = tabulate.tabulate(df, headers='keys', tablefmt='unsafehtml')
|
1486 |
+
|
1487 |
+
if exceptions:
|
1488 |
+
exception_metadatas = [x.metadata for x in exceptions]
|
1489 |
+
small_dict = {get_url(x['source'], from_str=True, short_name=True): get_short_name(x.get('exception')) for x in
|
1490 |
+
exception_metadatas}
|
1491 |
+
# if small_dict is empty dict, that's ok
|
1492 |
+
df = pd.DataFrame(small_dict.items(), columns=['source', 'exception'])
|
1493 |
+
df.index = df.index + 1
|
1494 |
+
df.index.name = 'index'
|
1495 |
+
exceptions_html = tabulate.tabulate(df, headers='keys', tablefmt='unsafehtml')
|
1496 |
+
else:
|
1497 |
+
exceptions_html = ''
|
1498 |
+
|
1499 |
+
if metadatas and exceptions:
|
1500 |
+
source_files_added = """\
|
1501 |
+
<html>
|
1502 |
+
<body>
|
1503 |
+
<p>
|
1504 |
+
Sources: <br>
|
1505 |
+
</p>
|
1506 |
+
<div style="overflow-y: auto;height:400px">
|
1507 |
+
{0}
|
1508 |
+
{1}
|
1509 |
+
</div>
|
1510 |
+
</body>
|
1511 |
+
</html>
|
1512 |
+
""".format(source_files_added, exceptions_html)
|
1513 |
+
elif metadatas:
|
1514 |
+
source_files_added = """\
|
1515 |
+
<html>
|
1516 |
+
<body>
|
1517 |
+
<p>
|
1518 |
+
Sources: <br>
|
1519 |
+
</p>
|
1520 |
+
<div style="overflow-y: auto;height:400px">
|
1521 |
+
{0}
|
1522 |
+
</div>
|
1523 |
+
</body>
|
1524 |
+
</html>
|
1525 |
+
""".format(source_files_added)
|
1526 |
+
elif exceptions_html:
|
1527 |
+
source_files_added = """\
|
1528 |
+
<html>
|
1529 |
+
<body>
|
1530 |
+
<p>
|
1531 |
+
Exceptions: <br>
|
1532 |
+
</p>
|
1533 |
+
<div style="overflow-y: auto;height:400px">
|
1534 |
+
{0}
|
1535 |
+
</div>
|
1536 |
+
</body>
|
1537 |
+
</html>
|
1538 |
+
""".format(exceptions_html)
|
1539 |
+
else:
|
1540 |
+
source_files_added = ""
|
1541 |
+
|
1542 |
+
return source_files_added
|
gradio_themes.py
CHANGED
@@ -1,6 +1,7 @@
|
|
1 |
from __future__ import annotations
|
2 |
from gradio.themes.soft import Soft
|
3 |
-
from gradio.themes
|
|
|
4 |
|
5 |
h2o_yellow = Color(
|
6 |
name="yellow",
|
@@ -74,6 +75,7 @@ class H2oTheme(Soft):
|
|
74 |
body_background_fill_dark="*neutral_900",
|
75 |
background_fill_primary_dark="*block_background_fill",
|
76 |
block_radius="0 0 8px 8px",
|
|
|
77 |
)
|
78 |
|
79 |
|
|
|
1 |
from __future__ import annotations
|
2 |
from gradio.themes.soft import Soft
|
3 |
+
from gradio.themes import Color
|
4 |
+
from gradio.themes.utils import colors, sizes
|
5 |
|
6 |
h2o_yellow = Color(
|
7 |
name="yellow",
|
|
|
75 |
body_background_fill_dark="*neutral_900",
|
76 |
background_fill_primary_dark="*block_background_fill",
|
77 |
block_radius="0 0 8px 8px",
|
78 |
+
checkbox_label_text_color_selected_dark='#000000',
|
79 |
)
|
80 |
|
81 |
|
h2oai_pipeline.py
ADDED
@@ -0,0 +1,54 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from transformers import TextGenerationPipeline
|
2 |
+
from transformers.pipelines.text_generation import ReturnType
|
3 |
+
|
4 |
+
from stopping import get_stopping
|
5 |
+
|
6 |
+
prompt_type = "human_bot"
|
7 |
+
human = "<human>:"
|
8 |
+
bot = "<bot>:"
|
9 |
+
|
10 |
+
# human-bot interaction like OIG dataset
|
11 |
+
prompt = """{human} {instruction}
|
12 |
+
{bot}""".format(
|
13 |
+
human=human,
|
14 |
+
instruction="{instruction}",
|
15 |
+
bot=bot,
|
16 |
+
)
|
17 |
+
|
18 |
+
|
19 |
+
class H2OTextGenerationPipeline(TextGenerationPipeline):
|
20 |
+
def __init__(self, *args, use_prompter=False, debug=False, chat=False, stream_output=False,
|
21 |
+
sanitize_bot_response=True, **kwargs):
|
22 |
+
super().__init__(*args, **kwargs)
|
23 |
+
self.use_prompter = use_prompter
|
24 |
+
self.prompt_text = None
|
25 |
+
if self.use_prompter:
|
26 |
+
from prompter import Prompter
|
27 |
+
self.prompter = Prompter(prompt_type, debug=debug, chat=chat, stream_output=stream_output)
|
28 |
+
else:
|
29 |
+
self.prompter = None
|
30 |
+
self.sanitize_bot_response = sanitize_bot_response
|
31 |
+
|
32 |
+
def preprocess(self, prompt_text, prefix="", handle_long_generation=None, **generate_kwargs):
|
33 |
+
prompt_text = prompt.format(instruction=prompt_text)
|
34 |
+
self.prompt_text = prompt_text
|
35 |
+
return super().preprocess(prompt_text, prefix=prefix, handle_long_generation=handle_long_generation,
|
36 |
+
**generate_kwargs)
|
37 |
+
|
38 |
+
def postprocess(self, model_outputs, return_type=ReturnType.FULL_TEXT, clean_up_tokenization_spaces=True):
|
39 |
+
records = super().postprocess(model_outputs, return_type=return_type,
|
40 |
+
clean_up_tokenization_spaces=clean_up_tokenization_spaces)
|
41 |
+
for rec in records:
|
42 |
+
if self.use_prompter:
|
43 |
+
outputs = rec['generated_text']
|
44 |
+
outputs = self.prompter.get_response(outputs, prompt=self.prompt_text,
|
45 |
+
sanitize_bot_response=self.sanitize_bot_response)
|
46 |
+
else:
|
47 |
+
outputs = rec['generated_text'].split(bot)[1].strip().split(human)[0].strip()
|
48 |
+
rec['generated_text'] = outputs
|
49 |
+
return records
|
50 |
+
|
51 |
+
def _forward(self, model_inputs, **generate_kwargs):
|
52 |
+
stopping_criteria = get_stopping(prompt_type, self.tokenizer, self.device, human=human, bot=bot)
|
53 |
+
generate_kwargs['stopping_criteria'] = stopping_criteria
|
54 |
+
return super()._forward(model_inputs, **generate_kwargs)
|
loaders.py
ADDED
@@ -0,0 +1,50 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
def get_loaders(llama_type, model_name, reward_type):
|
2 |
+
# NOTE: Some models need specific new prompt_type
|
3 |
+
# E.g. t5_xxl_true_nli_mixture has input format: "premise: PREMISE_TEXT hypothesis: HYPOTHESIS_TEXT".)
|
4 |
+
if llama_type:
|
5 |
+
from transformers import LlamaForCausalLM, LlamaTokenizer
|
6 |
+
model_loader = LlamaForCausalLM
|
7 |
+
tokenizer_loader = LlamaTokenizer
|
8 |
+
elif 'distilgpt2' in model_name.lower():
|
9 |
+
from transformers import AutoModelForCausalLM, AutoTokenizer
|
10 |
+
return AutoModelForCausalLM, AutoTokenizer
|
11 |
+
elif 'gpt2' in model_name.lower():
|
12 |
+
from transformers import GPT2LMHeadModel, GPT2Tokenizer
|
13 |
+
return GPT2LMHeadModel, GPT2Tokenizer
|
14 |
+
elif 'mbart-' in model_name.lower():
|
15 |
+
from transformers import MBartForConditionalGeneration, MBart50TokenizerFast
|
16 |
+
return MBartForConditionalGeneration, MBart50TokenizerFast
|
17 |
+
elif 't5' == model_name.lower() or \
|
18 |
+
't5-' in model_name.lower() or \
|
19 |
+
'flan-' in model_name.lower():
|
20 |
+
from transformers import AutoTokenizer, T5ForConditionalGeneration
|
21 |
+
return T5ForConditionalGeneration, AutoTokenizer
|
22 |
+
elif 'bigbird' in model_name:
|
23 |
+
from transformers import BigBirdPegasusForConditionalGeneration, AutoTokenizer
|
24 |
+
return BigBirdPegasusForConditionalGeneration, AutoTokenizer
|
25 |
+
elif 'bart-large-cnn-samsum' in model_name or 'flan-t5-base-samsum' in model_name:
|
26 |
+
from transformers import pipeline
|
27 |
+
return pipeline, "summarization"
|
28 |
+
elif reward_type or 'OpenAssistant/reward-model'.lower() in model_name.lower():
|
29 |
+
from transformers import AutoModelForSequenceClassification, AutoTokenizer
|
30 |
+
return AutoModelForSequenceClassification, AutoTokenizer
|
31 |
+
else:
|
32 |
+
from transformers import AutoTokenizer, AutoModelForCausalLM
|
33 |
+
model_loader = AutoModelForCausalLM
|
34 |
+
tokenizer_loader = AutoTokenizer
|
35 |
+
return model_loader, tokenizer_loader
|
36 |
+
|
37 |
+
|
38 |
+
def get_tokenizer(tokenizer_loader, tokenizer_base_model, local_files_only, resume_download, use_auth_token):
|
39 |
+
tokenizer = tokenizer_loader.from_pretrained(tokenizer_base_model,
|
40 |
+
local_files_only=local_files_only,
|
41 |
+
resume_download=resume_download,
|
42 |
+
use_auth_token=use_auth_token)
|
43 |
+
|
44 |
+
tokenizer.pad_token_id = 0 # different from the eos token
|
45 |
+
# when generating, we will use the logits of right-most token to predict the next token
|
46 |
+
# so the padding should be on the left,
|
47 |
+
# e.g. see: https://huggingface.co/transformers/v4.11.3/model_doc/t5.html#inference
|
48 |
+
tokenizer.padding_side = "left" # Allow batched inference
|
49 |
+
|
50 |
+
return tokenizer
|
prompter.py
CHANGED
@@ -1,4 +1,355 @@
|
|
1 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
2 |
|
3 |
|
4 |
class Prompter(object):
|
@@ -13,6 +364,12 @@ class Prompter(object):
|
|
13 |
self.stream_output = stream_output
|
14 |
self.repeat_penalty = repeat_penalty
|
15 |
self.allowed_repeat_line_length = allowed_repeat_line_length
|
|
|
|
|
|
|
|
|
|
|
|
|
16 |
|
17 |
def generate_prompt(self, data_point):
|
18 |
reduced = False
|
@@ -55,6 +412,18 @@ class Prompter(object):
|
|
55 |
for oi, output in enumerate(outputs):
|
56 |
if self.prompt_type in [0, '0', 'plain']:
|
57 |
output = clean_response(output)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
58 |
else:
|
59 |
# find first instance of prereponse
|
60 |
# prompt sometimes has odd characters, that mutate length,
|
|
|
1 |
+
import time
|
2 |
+
from enum import Enum
|
3 |
+
|
4 |
+
|
5 |
+
class PromptType(Enum):
|
6 |
+
plain = 0
|
7 |
+
instruct = 1
|
8 |
+
quality = 2
|
9 |
+
human_bot = 3
|
10 |
+
dai_faq = 4
|
11 |
+
summarize = 5
|
12 |
+
simple_instruct = 6
|
13 |
+
instruct_vicuna = 7
|
14 |
+
instruct_with_end = 8
|
15 |
+
human_bot_orig = 9
|
16 |
+
prompt_answer = 10
|
17 |
+
open_assistant = 11
|
18 |
+
wizard_lm = 12
|
19 |
+
wizard_mega = 13
|
20 |
+
|
21 |
+
|
22 |
+
prompt_type_to_model_name = {
|
23 |
+
'plain': [
|
24 |
+
'EleutherAI/gpt-j-6B',
|
25 |
+
'EleutherAI/pythia-6.9b',
|
26 |
+
'EleutherAI/pythia-12b',
|
27 |
+
'EleutherAI/pythia-12b-deduped',
|
28 |
+
'EleutherAI/gpt-neox-20b',
|
29 |
+
'decapoda-research/llama-7b-hf',
|
30 |
+
'decapoda-research/llama-13b-hf',
|
31 |
+
'decapoda-research/llama-30b-hf',
|
32 |
+
'decapoda-research/llama-65b-hf',
|
33 |
+
'facebook/mbart-large-50-many-to-many-mmt',
|
34 |
+
'philschmid/bart-large-cnn-samsum',
|
35 |
+
'philschmid/flan-t5-base-samsum',
|
36 |
+
'gpt2',
|
37 |
+
'distilgpt2',
|
38 |
+
'mosaicml/mpt-7b-storywriter',
|
39 |
+
'mosaicml/mpt-7b-instruct', # internal code handles instruct
|
40 |
+
'mosaicml/mpt-7b-chat', # NC, internal code handles instruct
|
41 |
+
'gptj', # internally handles prompting
|
42 |
+
'llama', # internally handles prompting
|
43 |
+
],
|
44 |
+
'prompt_answer': [
|
45 |
+
'h2oai/h2ogpt-gm-oasst1-en-1024-20b',
|
46 |
+
'h2oai/h2ogpt-gm-oasst1-en-1024-12b',
|
47 |
+
'h2oai/h2ogpt-gm-oasst1-multilang-1024-20b',
|
48 |
+
'h2oai/h2ogpt-gm-oasst1-en-2048-open-llama-7b-preview-300bt',
|
49 |
+
'h2oai/h2ogpt-gm-oasst1-en-2048-open-llama-7b-preview-300bt-v2',
|
50 |
+
],
|
51 |
+
'instruct': [],
|
52 |
+
'instruct_with_end': ['databricks/dolly-v2-12b'],
|
53 |
+
'quality': [],
|
54 |
+
'human_bot': [
|
55 |
+
'h2oai/h2ogpt-oasst1-512-12b',
|
56 |
+
'h2oai/h2ogpt-oasst1-512-20b',
|
57 |
+
'h2oai/h2ogpt-oig-oasst1-256-6_9b',
|
58 |
+
'h2oai/h2ogpt-oig-oasst1-512-6_9b',
|
59 |
+
'h2oai/h2ogpt-research-oasst1-512-30b', # private
|
60 |
+
],
|
61 |
+
'dai_faq': [],
|
62 |
+
'summarize': [],
|
63 |
+
'simple_instruct': ['t5-small', 't5-large', 'google/flan-t5', 'google/flan-t5-xxl', 'google/flan-ul2'],
|
64 |
+
'instruct_vicuna': ['AlekseyKorshuk/vicuna-7b', 'TheBloke/stable-vicuna-13B-HF', 'junelee/wizard-vicuna-13b'],
|
65 |
+
'human_bot_orig': ['togethercomputer/GPT-NeoXT-Chat-Base-20B'],
|
66 |
+
"open_assistant": ['OpenAssistant/oasst-sft-7-llama-30b-xor', 'oasst-sft-7-llama-30b'],
|
67 |
+
"wizard_lm": ['ehartford/WizardLM-7B-Uncensored', 'ehartford/WizardLM-13B-Uncensored'],
|
68 |
+
"wizard_mega": ['openaccess-ai-collective/wizard-mega-13b'],
|
69 |
+
}
|
70 |
+
|
71 |
+
inv_prompt_type_to_model_name = {v.strip(): k for k, l in prompt_type_to_model_name.items() for v in l}
|
72 |
+
inv_prompt_type_to_model_lower = {v.strip().lower(): k for k, l in prompt_type_to_model_name.items() for v in l}
|
73 |
+
|
74 |
+
prompt_types_strings = []
|
75 |
+
for p in PromptType:
|
76 |
+
prompt_types_strings.extend([p.name])
|
77 |
+
|
78 |
+
prompt_types = []
|
79 |
+
for p in PromptType:
|
80 |
+
prompt_types.extend([p.name, p.value, str(p.value)])
|
81 |
+
|
82 |
+
|
83 |
+
def get_prompt(prompt_type, chat, context, reduced):
|
84 |
+
if prompt_type in [-1, "-1", "plain"]:
|
85 |
+
promptA = promptB = PreInstruct = PreInput = PreResponse = ''
|
86 |
+
terminate_response = []
|
87 |
+
chat_sep = ''
|
88 |
+
humanstr = ''
|
89 |
+
botstr = ''
|
90 |
+
elif prompt_type == 'simple_instruct':
|
91 |
+
promptA = promptB = PreInstruct = PreInput = PreResponse = None
|
92 |
+
terminate_response = []
|
93 |
+
chat_sep = '\n'
|
94 |
+
humanstr = ''
|
95 |
+
botstr = ''
|
96 |
+
elif prompt_type in [0, "0", "instruct"] or prompt_type in [7, "7", "instruct_with_end"]:
|
97 |
+
promptA = 'Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.\n' if not (
|
98 |
+
chat and reduced) else ''
|
99 |
+
promptB = 'Below is an instruction that describes a task. Write a response that appropriately completes the request.\n' if not (
|
100 |
+
chat and reduced) else ''
|
101 |
+
|
102 |
+
PreInstruct = """
|
103 |
+
### Instruction:
|
104 |
+
"""
|
105 |
+
|
106 |
+
PreInput = """
|
107 |
+
### Input:
|
108 |
+
"""
|
109 |
+
|
110 |
+
PreResponse = """
|
111 |
+
### Response:
|
112 |
+
"""
|
113 |
+
if prompt_type in [7, "7", "instruct_with_end"]:
|
114 |
+
terminate_response = ['### End']
|
115 |
+
else:
|
116 |
+
terminate_response = None
|
117 |
+
chat_sep = '\n'
|
118 |
+
humanstr = PreInstruct
|
119 |
+
botstr = PreResponse
|
120 |
+
elif prompt_type in [1, "1", "quality"]:
|
121 |
+
promptA = 'Write a detailed high-quality, accurate, fair, Response with about 100 words by following the Instruction as applied on the Input.\n' if not (
|
122 |
+
chat and reduced) else ''
|
123 |
+
promptB = 'Write a detailed high-quality, accurate, fair, Response with about 100 words by following the Instruction.\n' if not (
|
124 |
+
chat and reduced) else ''
|
125 |
+
|
126 |
+
PreInstruct = """
|
127 |
+
### Instruction:
|
128 |
+
"""
|
129 |
+
|
130 |
+
PreInput = """
|
131 |
+
### Input:
|
132 |
+
"""
|
133 |
+
|
134 |
+
PreResponse = """
|
135 |
+
### Response:
|
136 |
+
"""
|
137 |
+
terminate_response = None
|
138 |
+
chat_sep = '\n'
|
139 |
+
humanstr = PreInstruct # first thing human says
|
140 |
+
botstr = PreResponse # first thing bot says
|
141 |
+
elif prompt_type in [2, "2", "human_bot", 9, "9", "human_bot_orig"]:
|
142 |
+
human = '<human>:'
|
143 |
+
bot = "<bot>:"
|
144 |
+
if reduced or context or prompt_type in [2, "2", "human_bot"]:
|
145 |
+
preprompt = ''
|
146 |
+
else:
|
147 |
+
cur_date = time.strftime('%Y-%m-%d')
|
148 |
+
cur_time = time.strftime('%H:%M:%S %p %Z')
|
149 |
+
|
150 |
+
PRE_PROMPT = """\
|
151 |
+
Current Date: {}
|
152 |
+
Current Time: {}
|
153 |
+
|
154 |
+
"""
|
155 |
+
preprompt = PRE_PROMPT.format(cur_date, cur_time)
|
156 |
+
start = human
|
157 |
+
promptB = promptA = '%s%s ' % (preprompt, start)
|
158 |
+
|
159 |
+
PreInstruct = ""
|
160 |
+
|
161 |
+
PreInput = None
|
162 |
+
|
163 |
+
if reduced:
|
164 |
+
# when making context, want it to appear as-if LLM generated, which starts with space after :
|
165 |
+
PreResponse = bot + ' '
|
166 |
+
else:
|
167 |
+
# normally LLM adds space after this, because was how trained.
|
168 |
+
# if add space here, non-unique tokenization will often make LLM produce wrong output
|
169 |
+
PreResponse = bot
|
170 |
+
|
171 |
+
terminate_response = [start, PreResponse]
|
172 |
+
chat_sep = '\n'
|
173 |
+
humanstr = human # tag before human talks
|
174 |
+
botstr = bot # tag before bot talks
|
175 |
+
elif prompt_type in [3, "3", "dai_faq"]:
|
176 |
+
promptA = ''
|
177 |
+
promptB = 'Answer the following Driverless AI question.\n'
|
178 |
+
|
179 |
+
PreInstruct = """
|
180 |
+
### Driverless AI frequently asked question:
|
181 |
+
"""
|
182 |
+
|
183 |
+
PreInput = None
|
184 |
+
|
185 |
+
PreResponse = """
|
186 |
+
### Driverless AI documentation answer:
|
187 |
+
"""
|
188 |
+
terminate_response = ['\n\n']
|
189 |
+
chat_sep = terminate_response
|
190 |
+
humanstr = PreInstruct
|
191 |
+
botstr = PreResponse
|
192 |
+
elif prompt_type in [5, "5", "summarize"]:
|
193 |
+
promptA = promptB = PreInput = ''
|
194 |
+
PreInstruct = '## Main Text\n\n'
|
195 |
+
PreResponse = '\n\n## Summary\n\n'
|
196 |
+
terminate_response = None
|
197 |
+
chat_sep = '\n'
|
198 |
+
humanstr = PreInstruct
|
199 |
+
botstr = PreResponse
|
200 |
+
elif prompt_type in [6, "6", "instruct_vicuna"]:
|
201 |
+
promptA = promptB = "A chat between a curious human and an artificial intelligence assistant. " \
|
202 |
+
"The assistant gives helpful, detailed, and polite answers to the human's questions." if not (
|
203 |
+
chat and reduced) else ''
|
204 |
+
|
205 |
+
PreInstruct = """
|
206 |
+
### Human:
|
207 |
+
"""
|
208 |
+
|
209 |
+
PreInput = None
|
210 |
+
|
211 |
+
PreResponse = """
|
212 |
+
### Assistant:
|
213 |
+
"""
|
214 |
+
terminate_response = [
|
215 |
+
'### Human:'] # but only allow terminate after prompt is found correctly, else can't terminate
|
216 |
+
chat_sep = '\n'
|
217 |
+
humanstr = PreInstruct
|
218 |
+
botstr = PreResponse
|
219 |
+
elif prompt_type in [10, "10", "prompt_answer"]:
|
220 |
+
preprompt = ''
|
221 |
+
prompt_tokens = "<|prompt|>"
|
222 |
+
answer_tokens = "<|answer|>"
|
223 |
+
start = prompt_tokens
|
224 |
+
promptB = promptA = '%s%s' % (preprompt, start)
|
225 |
+
PreInstruct = ""
|
226 |
+
PreInput = None
|
227 |
+
PreResponse = answer_tokens
|
228 |
+
eos = '<|endoftext|>' # neox eos
|
229 |
+
terminate_response = [start, PreResponse, eos]
|
230 |
+
chat_sep = eos
|
231 |
+
humanstr = prompt_tokens
|
232 |
+
botstr = answer_tokens
|
233 |
+
elif prompt_type in [11, "11", "open_assistant"]:
|
234 |
+
# From added_tokens.json
|
235 |
+
preprompt = ''
|
236 |
+
prompt_tokens = "<|prompter|>"
|
237 |
+
answer_tokens = "<|assistant|>"
|
238 |
+
start = prompt_tokens
|
239 |
+
promptB = promptA = '%s%s' % (preprompt, start)
|
240 |
+
PreInstruct = ""
|
241 |
+
PreInput = None
|
242 |
+
PreResponse = answer_tokens
|
243 |
+
pend = "<|prefix_end|>"
|
244 |
+
eos = "</s>"
|
245 |
+
terminate_response = [start, PreResponse, pend, eos]
|
246 |
+
chat_sep = eos
|
247 |
+
humanstr = prompt_tokens
|
248 |
+
botstr = answer_tokens
|
249 |
+
elif prompt_type in [12, "12", "wizard_lm"]:
|
250 |
+
# https://github.com/ehartford/WizardLM/blob/main/src/train_freeform.py
|
251 |
+
preprompt = ''
|
252 |
+
start = ''
|
253 |
+
promptB = promptA = '%s%s' % (preprompt, start)
|
254 |
+
PreInstruct = ""
|
255 |
+
PreInput = None
|
256 |
+
PreResponse = "\n\n### Response"
|
257 |
+
eos = "</s>"
|
258 |
+
terminate_response = [PreResponse, eos]
|
259 |
+
chat_sep = eos
|
260 |
+
humanstr = promptA
|
261 |
+
botstr = PreResponse
|
262 |
+
elif prompt_type in [13, "13", "wizard_mega"]:
|
263 |
+
preprompt = ''
|
264 |
+
start = ''
|
265 |
+
promptB = promptA = '%s%s' % (preprompt, start)
|
266 |
+
PreInstruct = """
|
267 |
+
### Instruction:
|
268 |
+
"""
|
269 |
+
PreInput = None
|
270 |
+
PreResponse = """
|
271 |
+
### Assistant:
|
272 |
+
"""
|
273 |
+
terminate_response = [PreResponse]
|
274 |
+
chat_sep = '\n'
|
275 |
+
humanstr = PreInstruct
|
276 |
+
botstr = PreResponse
|
277 |
+
else:
|
278 |
+
raise RuntimeError("No such prompt_type=%s" % prompt_type)
|
279 |
+
|
280 |
+
return promptA, promptB, PreInstruct, PreInput, PreResponse, terminate_response, chat_sep, humanstr, botstr
|
281 |
+
|
282 |
+
|
283 |
+
def generate_prompt(data_point, prompt_type, chat, reduced):
|
284 |
+
context = data_point.get('context')
|
285 |
+
if context is None:
|
286 |
+
context = ''
|
287 |
+
instruction = data_point.get('instruction')
|
288 |
+
input = data_point.get('input')
|
289 |
+
output = data_point.get('output')
|
290 |
+
prompt_type = data_point.get('prompt_type', prompt_type)
|
291 |
+
assert prompt_type in prompt_types, "Bad prompt type: %s" % prompt_type
|
292 |
+
promptA, promptB, PreInstruct, PreInput, PreResponse, \
|
293 |
+
terminate_response, chat_sep, humanstr, botstr = get_prompt(prompt_type, chat, context, reduced)
|
294 |
+
|
295 |
+
prompt = context if not reduced else ''
|
296 |
+
|
297 |
+
if input and promptA:
|
298 |
+
prompt += f"""{promptA}"""
|
299 |
+
elif promptB:
|
300 |
+
prompt += f"""{promptB}"""
|
301 |
+
|
302 |
+
if instruction and PreInstruct is not None and input and PreInput is not None:
|
303 |
+
prompt += f"""{PreInstruct}{instruction}{PreInput}{input}"""
|
304 |
+
prompt = inject_newline(prompt_type, prompt)
|
305 |
+
elif instruction and input and PreInstruct is None and PreInput is not None:
|
306 |
+
prompt += f"""{PreInput}{instruction}
|
307 |
+
{input}"""
|
308 |
+
prompt = inject_newline(prompt_type, prompt)
|
309 |
+
elif input and instruction and PreInput is None and PreInstruct is not None:
|
310 |
+
prompt += f"""{PreInstruct}{instruction}
|
311 |
+
{input}"""
|
312 |
+
prompt = inject_newline(prompt_type, prompt)
|
313 |
+
elif instruction and PreInstruct is not None:
|
314 |
+
prompt += f"""{PreInstruct}{instruction}"""
|
315 |
+
prompt = inject_newline(prompt_type, prompt)
|
316 |
+
elif input and PreInput is not None:
|
317 |
+
prompt += f"""{PreInput}{input}"""
|
318 |
+
prompt = inject_newline(prompt_type, prompt)
|
319 |
+
elif input and instruction and PreInput is not None:
|
320 |
+
prompt += f"""{PreInput}{instruction}{input}"""
|
321 |
+
prompt = inject_newline(prompt_type, prompt)
|
322 |
+
elif input and instruction and PreInstruct is not None:
|
323 |
+
prompt += f"""{PreInstruct}{instruction}{input}"""
|
324 |
+
prompt = inject_newline(prompt_type, prompt)
|
325 |
+
elif input and instruction:
|
326 |
+
# i.e. for simple_instruct
|
327 |
+
prompt += f"""{instruction}: {input}"""
|
328 |
+
prompt = inject_newline(prompt_type, prompt)
|
329 |
+
elif input:
|
330 |
+
prompt += f"""{input}"""
|
331 |
+
prompt = inject_newline(prompt_type, prompt)
|
332 |
+
elif instruction:
|
333 |
+
prompt += f"""{instruction}"""
|
334 |
+
prompt = inject_newline(prompt_type, prompt)
|
335 |
+
|
336 |
+
if PreResponse is not None:
|
337 |
+
prompt += f"""{PreResponse}"""
|
338 |
+
pre_response = PreResponse # Don't use strip
|
339 |
+
else:
|
340 |
+
pre_response = ''
|
341 |
+
|
342 |
+
if output:
|
343 |
+
prompt += f"""{output}"""
|
344 |
+
|
345 |
+
return prompt, pre_response, terminate_response, chat_sep
|
346 |
+
|
347 |
+
|
348 |
+
def inject_newline(prompt_type, prompt):
|
349 |
+
if prompt_type not in [-1, '-1', 'plain', 'simple_instruct']:
|
350 |
+
# only add new line if structured prompt, while 'plain' is just generation of next tokens from input
|
351 |
+
prompt += '\n'
|
352 |
+
return prompt
|
353 |
|
354 |
|
355 |
class Prompter(object):
|
|
|
364 |
self.stream_output = stream_output
|
365 |
self.repeat_penalty = repeat_penalty
|
366 |
self.allowed_repeat_line_length = allowed_repeat_line_length
|
367 |
+
self.prompt = None
|
368 |
+
context = "" # not for chat context
|
369 |
+
reduced = False # not for chat context
|
370 |
+
self.promptA, self.promptB, self.PreInstruct, self.PreInput, self.PreResponse, \
|
371 |
+
self.terminate_response, self.chat_sep, self.humanstr, self.botstr = \
|
372 |
+
get_prompt(prompt_type, chat, context, reduced)
|
373 |
|
374 |
def generate_prompt(self, data_point):
|
375 |
reduced = False
|
|
|
412 |
for oi, output in enumerate(outputs):
|
413 |
if self.prompt_type in [0, '0', 'plain']:
|
414 |
output = clean_response(output)
|
415 |
+
elif prompt is None:
|
416 |
+
# then use most basic parsing like pipeline
|
417 |
+
if self.botstr in output:
|
418 |
+
if self.humanstr:
|
419 |
+
output = clean_response(output.split(self.botstr)[1].strip().split(self.humanstr)[0].strip())
|
420 |
+
else:
|
421 |
+
# i.e. use after bot but only up to next bot
|
422 |
+
output = clean_response(output.split(self.botstr)[1].strip().split(self.botstr)[0].strip())
|
423 |
+
else:
|
424 |
+
# output = clean_response(output.strip())
|
425 |
+
# assume just not printed yet
|
426 |
+
output = ""
|
427 |
else:
|
428 |
# find first instance of prereponse
|
429 |
# prompt sometimes has odd characters, that mutate length,
|
requirements.txt
CHANGED
@@ -2,7 +2,7 @@
|
|
2 |
datasets==2.12.0
|
3 |
sentencepiece==0.1.97
|
4 |
accelerate==0.18.0
|
5 |
-
gradio==3.
|
6 |
huggingface_hub==0.14.1
|
7 |
appdirs==1.4.4
|
8 |
fire==0.5.0
|
@@ -35,7 +35,7 @@ tensorboard==2.12.1
|
|
35 |
neptune==1.1.1
|
36 |
|
37 |
# for gradio client
|
38 |
-
gradio_client==0.
|
39 |
beautifulsoup4==4.12.2
|
40 |
markdown==3.4.1
|
41 |
|
@@ -45,7 +45,58 @@ pytest-xdist==3.2.1
|
|
45 |
nltk==3.8.1
|
46 |
textstat==0.7.3
|
47 |
pandoc==2.3
|
48 |
-
pypandoc==1.11
|
|
|
49 |
openpyxl==3.1.2
|
50 |
lm_dataformat==0.0.20
|
51 |
bioc==2.0
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
2 |
datasets==2.12.0
|
3 |
sentencepiece==0.1.97
|
4 |
accelerate==0.18.0
|
5 |
+
gradio==3.31.0
|
6 |
huggingface_hub==0.14.1
|
7 |
appdirs==1.4.4
|
8 |
fire==0.5.0
|
|
|
35 |
neptune==1.1.1
|
36 |
|
37 |
# for gradio client
|
38 |
+
gradio_client==0.2.5
|
39 |
beautifulsoup4==4.12.2
|
40 |
markdown==3.4.1
|
41 |
|
|
|
45 |
nltk==3.8.1
|
46 |
textstat==0.7.3
|
47 |
pandoc==2.3
|
48 |
+
#pypandoc==1.11
|
49 |
+
pypandoc_binary==1.11
|
50 |
openpyxl==3.1.2
|
51 |
lm_dataformat==0.0.20
|
52 |
bioc==2.0
|
53 |
+
# To install with constraints
|
54 |
+
# grep -v '#\|peft' requirements.txt > req_constraints.txt ; pip install -r requirements_optional_langchain.txt -c req_constraints.txt
|
55 |
+
|
56 |
+
# optional for chat with PDF
|
57 |
+
langchain==0.0.178
|
58 |
+
pypdf==3.8.1
|
59 |
+
tiktoken==0.3.3
|
60 |
+
# avoid textract, requires old six
|
61 |
+
#textract==1.6.5
|
62 |
+
# choose:
|
63 |
+
#faiss-cpu
|
64 |
+
faiss-gpu==1.7.2
|
65 |
+
|
66 |
+
# for HF embeddings
|
67 |
+
sentence_transformers==2.2.2
|
68 |
+
# for OpenAI embeddings (requires key)
|
69 |
+
openai==0.27.6
|
70 |
+
|
71 |
+
# local vector db
|
72 |
+
chromadb==0.3.23
|
73 |
+
# server vector db
|
74 |
+
#pymilvus==2.2.8
|
75 |
+
|
76 |
+
# weak url support, if can't install opencv etc. If comment-in this one, then comment-out unstructured[local-inference]==0.6.6
|
77 |
+
# unstructured==0.6.6
|
78 |
+
|
79 |
+
# strong support for images
|
80 |
+
# Requires on Ubuntu: sudo apt-get install libmagic-dev poppler-utils tesseract-ocr libreoffice
|
81 |
+
unstructured[local-inference]==0.6.6
|
82 |
+
#pdf2image==1.16.3
|
83 |
+
#pytesseract==0.3.10
|
84 |
+
pillow
|
85 |
+
|
86 |
+
pdfminer.six==20221105
|
87 |
+
urllib3==1.26.6
|
88 |
+
requests_file==1.5.1
|
89 |
+
|
90 |
+
#pdf2image==1.16.3
|
91 |
+
#pytesseract==0.3.10
|
92 |
+
tabulate==0.9.0
|
93 |
+
# FYI pandoc already part of requirements.txt
|
94 |
+
|
95 |
+
jq==1.4.1
|
96 |
+
|
97 |
+
# to check licenses
|
98 |
+
# Run: pip-licenses|grep -v 'BSD\|Apache\|MIT'
|
99 |
+
pip-licenses==4.3.0
|
100 |
+
gpt4all==0.2.3
|
101 |
+
llama-cpp-python==0.1.54
|
102 |
+
python-dotenv==1.0.0
|
utils.py
CHANGED
@@ -1,4 +1,6 @@
|
|
|
|
1 |
import functools
|
|
|
2 |
import os
|
3 |
import gc
|
4 |
import pathlib
|
@@ -12,6 +14,9 @@ import traceback
|
|
12 |
import zipfile
|
13 |
from datetime import datetime
|
14 |
import filelock
|
|
|
|
|
|
|
15 |
import numpy as np
|
16 |
import pandas as pd
|
17 |
|
@@ -53,7 +58,11 @@ def clear_torch_cache():
|
|
53 |
|
54 |
|
55 |
def ping():
|
56 |
-
|
|
|
|
|
|
|
|
|
57 |
|
58 |
|
59 |
def get_torch_allocated():
|
@@ -61,6 +70,16 @@ def get_torch_allocated():
|
|
61 |
return torch.cuda.memory_allocated()
|
62 |
|
63 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
64 |
def system_info():
|
65 |
import psutil
|
66 |
|
@@ -111,21 +130,26 @@ def system_info_print():
|
|
111 |
return "Error: %s" % str(e)
|
112 |
|
113 |
|
114 |
-
def zip_data(root_dirs=None, zip_file=None, base_dir='./'):
|
115 |
try:
|
116 |
return _zip_data(zip_file=zip_file, base_dir=base_dir, root_dirs=root_dirs)
|
117 |
except Exception as e:
|
118 |
traceback.print_exc()
|
119 |
print('Exception in zipping: %s' % str(e))
|
|
|
|
|
120 |
|
121 |
|
122 |
def _zip_data(root_dirs=None, zip_file=None, base_dir='./'):
|
|
|
|
|
123 |
if zip_file is None:
|
124 |
datetime_str = str(datetime.now()).replace(" ", "_").replace(":", "_")
|
125 |
host_name = os.getenv('HF_HOSTNAME', 'emptyhost')
|
126 |
zip_file = "data_%s_%s.zip" % (datetime_str, host_name)
|
127 |
assert root_dirs is not None
|
128 |
-
|
|
|
129 |
with zipfile.ZipFile(zip_file, "w") as expt_zip:
|
130 |
for root_dir in root_dirs:
|
131 |
if root_dir is None:
|
@@ -237,6 +261,7 @@ class NullContext(threading.local):
|
|
237 |
Used as a stand-in if a particular block of code is only sometimes
|
238 |
used with a normal context manager:
|
239 |
"""
|
|
|
240 |
def __init__(self, *args, **kwargs):
|
241 |
pass
|
242 |
|
@@ -270,16 +295,18 @@ class ThreadException(Exception):
|
|
270 |
class EThread(threading.Thread):
|
271 |
# Function that raises the custom exception
|
272 |
def __init__(self, group=None, target=None, name=None,
|
273 |
-
args=(), kwargs=None, *, daemon=None, bucket=None):
|
274 |
self.bucket = bucket
|
275 |
-
self.streamer =
|
276 |
self.exc = None
|
|
|
277 |
super().__init__(group=group, target=target, name=name, args=args, kwargs=kwargs, daemon=daemon)
|
278 |
|
279 |
def run(self):
|
280 |
# Variable that stores the exception, if raised by someFunction
|
281 |
try:
|
282 |
-
|
|
|
283 |
except BaseException as e:
|
284 |
print("thread exception: %s" % str(sys.exc_info()))
|
285 |
self.bucket.put(sys.exc_info())
|
@@ -287,6 +314,10 @@ class EThread(threading.Thread):
|
|
287 |
if self.streamer:
|
288 |
print("make stop: %s" % str(sys.exc_info()), flush=True)
|
289 |
self.streamer.do_stop = True
|
|
|
|
|
|
|
|
|
290 |
|
291 |
def join(self, timeout=None):
|
292 |
threading.Thread.join(self)
|
@@ -295,3 +326,443 @@ class EThread(threading.Thread):
|
|
295 |
# if any was caught
|
296 |
if self.exc:
|
297 |
raise self.exc
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import contextlib
|
2 |
import functools
|
3 |
+
import hashlib
|
4 |
import os
|
5 |
import gc
|
6 |
import pathlib
|
|
|
14 |
import zipfile
|
15 |
from datetime import datetime
|
16 |
import filelock
|
17 |
+
import requests, uuid
|
18 |
+
from typing import Tuple, Callable, Dict
|
19 |
+
from concurrent.futures import ProcessPoolExecutor
|
20 |
import numpy as np
|
21 |
import pandas as pd
|
22 |
|
|
|
58 |
|
59 |
|
60 |
def ping():
|
61 |
+
try:
|
62 |
+
print('Ping: %s' % str(datetime.now()), flush=True)
|
63 |
+
except AttributeError:
|
64 |
+
# some programs wrap print and will fail with flush passed
|
65 |
+
pass
|
66 |
|
67 |
|
68 |
def get_torch_allocated():
|
|
|
70 |
return torch.cuda.memory_allocated()
|
71 |
|
72 |
|
73 |
+
def get_device():
|
74 |
+
import torch
|
75 |
+
if torch.cuda.is_available():
|
76 |
+
device = "cuda"
|
77 |
+
else:
|
78 |
+
device = "cpu"
|
79 |
+
|
80 |
+
return device
|
81 |
+
|
82 |
+
|
83 |
def system_info():
|
84 |
import psutil
|
85 |
|
|
|
130 |
return "Error: %s" % str(e)
|
131 |
|
132 |
|
133 |
+
def zip_data(root_dirs=None, zip_file=None, base_dir='./', fail_any_exception=False):
|
134 |
try:
|
135 |
return _zip_data(zip_file=zip_file, base_dir=base_dir, root_dirs=root_dirs)
|
136 |
except Exception as e:
|
137 |
traceback.print_exc()
|
138 |
print('Exception in zipping: %s' % str(e))
|
139 |
+
if not fail_any_exception:
|
140 |
+
raise
|
141 |
|
142 |
|
143 |
def _zip_data(root_dirs=None, zip_file=None, base_dir='./'):
|
144 |
+
if isinstance(root_dirs, str):
|
145 |
+
root_dirs = [root_dirs]
|
146 |
if zip_file is None:
|
147 |
datetime_str = str(datetime.now()).replace(" ", "_").replace(":", "_")
|
148 |
host_name = os.getenv('HF_HOSTNAME', 'emptyhost')
|
149 |
zip_file = "data_%s_%s.zip" % (datetime_str, host_name)
|
150 |
assert root_dirs is not None
|
151 |
+
if not os.path.isdir(os.path.dirname(zip_file)):
|
152 |
+
os.makedirs(os.path.dirname(zip_file), exist_ok=True)
|
153 |
with zipfile.ZipFile(zip_file, "w") as expt_zip:
|
154 |
for root_dir in root_dirs:
|
155 |
if root_dir is None:
|
|
|
261 |
Used as a stand-in if a particular block of code is only sometimes
|
262 |
used with a normal context manager:
|
263 |
"""
|
264 |
+
|
265 |
def __init__(self, *args, **kwargs):
|
266 |
pass
|
267 |
|
|
|
295 |
class EThread(threading.Thread):
|
296 |
# Function that raises the custom exception
|
297 |
def __init__(self, group=None, target=None, name=None,
|
298 |
+
args=(), kwargs=None, *, daemon=None, streamer=None, bucket=None):
|
299 |
self.bucket = bucket
|
300 |
+
self.streamer = streamer
|
301 |
self.exc = None
|
302 |
+
self._return = None
|
303 |
super().__init__(group=group, target=target, name=name, args=args, kwargs=kwargs, daemon=daemon)
|
304 |
|
305 |
def run(self):
|
306 |
# Variable that stores the exception, if raised by someFunction
|
307 |
try:
|
308 |
+
if self._target is not None:
|
309 |
+
self._return = self._target(*self._args, **self._kwargs)
|
310 |
except BaseException as e:
|
311 |
print("thread exception: %s" % str(sys.exc_info()))
|
312 |
self.bucket.put(sys.exc_info())
|
|
|
314 |
if self.streamer:
|
315 |
print("make stop: %s" % str(sys.exc_info()), flush=True)
|
316 |
self.streamer.do_stop = True
|
317 |
+
finally:
|
318 |
+
# Avoid a refcycle if the thread is running a function with
|
319 |
+
# an argument that has a member that points to the thread.
|
320 |
+
del self._target, self._args, self._kwargs
|
321 |
|
322 |
def join(self, timeout=None):
|
323 |
threading.Thread.join(self)
|
|
|
326 |
# if any was caught
|
327 |
if self.exc:
|
328 |
raise self.exc
|
329 |
+
return self._return
|
330 |
+
|
331 |
+
|
332 |
+
def import_matplotlib():
|
333 |
+
import matplotlib
|
334 |
+
matplotlib.use('agg')
|
335 |
+
# KEEP THESE HERE! START
|
336 |
+
import matplotlib.pyplot as plt
|
337 |
+
import pandas as pd
|
338 |
+
# to avoid dlopen deadlock in fork
|
339 |
+
import pandas.core.computation.expressions as pd_expressions
|
340 |
+
import pandas._libs.groupby as pd_libgroupby
|
341 |
+
import pandas._libs.reduction as pd_libreduction
|
342 |
+
import pandas.core.algorithms as pd_algorithms
|
343 |
+
import pandas.core.common as pd_com
|
344 |
+
import numpy as np
|
345 |
+
# KEEP THESE HERE! END
|
346 |
+
|
347 |
+
|
348 |
+
def get_sha(value):
|
349 |
+
return hashlib.md5(str(value).encode('utf-8')).hexdigest()
|
350 |
+
|
351 |
+
|
352 |
+
def sanitize_filename(name):
|
353 |
+
"""
|
354 |
+
Sanitize file *base* names.
|
355 |
+
:param name: name to sanitize
|
356 |
+
:return:
|
357 |
+
"""
|
358 |
+
bad_chars = ['[', ']', ',', '/', '\\', '\\w', '\\s', '-', '+', '\"', '\'', '>', '<', ' ', '=', ')', '(', ':', '^']
|
359 |
+
for char in bad_chars:
|
360 |
+
name = name.replace(char, "_")
|
361 |
+
|
362 |
+
length = len(name)
|
363 |
+
file_length_limit = 250 # bit smaller than 256 for safety
|
364 |
+
sha_length = 32
|
365 |
+
real_length_limit = file_length_limit - (sha_length + 2)
|
366 |
+
if length > file_length_limit:
|
367 |
+
sha = get_sha(name)
|
368 |
+
half_real_length_limit = max(1, int(real_length_limit / 2))
|
369 |
+
name = name[0:half_real_length_limit] + "_" + sha + "_" + name[length - half_real_length_limit:length]
|
370 |
+
|
371 |
+
return name
|
372 |
+
|
373 |
+
|
374 |
+
def shutil_rmtree_simple(*args, **kwargs):
|
375 |
+
path = args[0]
|
376 |
+
assert not os.path.samefile(path, "./tmp"), "Should not be trying to remove entire data directory: %s" % str(path)
|
377 |
+
# print("Removing path %s" % args[0]) # for debugging
|
378 |
+
return shutil.rmtree(*args, **kwargs)
|
379 |
+
|
380 |
+
|
381 |
+
def remove_simple(path: str):
|
382 |
+
try:
|
383 |
+
if path is not None and os.path.exists(path):
|
384 |
+
if os.path.isdir(path):
|
385 |
+
shutil_rmtree_simple(path, ignore_errors=True)
|
386 |
+
else:
|
387 |
+
with contextlib.suppress(FileNotFoundError):
|
388 |
+
os.remove(path)
|
389 |
+
except:
|
390 |
+
pass
|
391 |
+
|
392 |
+
|
393 |
+
def makedirs(path, exist_ok=True):
|
394 |
+
"""
|
395 |
+
Avoid some inefficiency in os.makedirs()
|
396 |
+
:param path:
|
397 |
+
:param exist_ok:
|
398 |
+
:return:
|
399 |
+
"""
|
400 |
+
if os.path.isdir(path) and os.path.exists(path):
|
401 |
+
assert exist_ok, "Path already exists"
|
402 |
+
return path
|
403 |
+
os.makedirs(path, exist_ok=exist_ok)
|
404 |
+
|
405 |
+
|
406 |
+
def atomic_move_simple(src, dst):
|
407 |
+
try:
|
408 |
+
shutil.move(src, dst)
|
409 |
+
except (shutil.Error, FileExistsError):
|
410 |
+
pass
|
411 |
+
remove_simple(src)
|
412 |
+
|
413 |
+
|
414 |
+
def download_simple(url, dest=None, print_func=None):
|
415 |
+
if print_func is not None:
|
416 |
+
print_func("BEGIN get url %s" % str(url))
|
417 |
+
if url.startswith("file://"):
|
418 |
+
from requests_file import FileAdapter
|
419 |
+
s = requests.Session()
|
420 |
+
s.mount('file://', FileAdapter())
|
421 |
+
url_data = s.get(url, stream=True)
|
422 |
+
else:
|
423 |
+
url_data = requests.get(url, stream=True)
|
424 |
+
if dest is None:
|
425 |
+
dest = os.path.basename(url)
|
426 |
+
if url_data.status_code != requests.codes.ok:
|
427 |
+
msg = "Cannot get url %s, code: %s, reason: %s" % (
|
428 |
+
str(url),
|
429 |
+
str(url_data.status_code),
|
430 |
+
str(url_data.reason),
|
431 |
+
)
|
432 |
+
raise requests.exceptions.RequestException(msg)
|
433 |
+
url_data.raw.decode_content = True
|
434 |
+
makedirs(os.path.dirname(dest), exist_ok=True)
|
435 |
+
uuid_tmp = str(uuid.uuid4())[:6]
|
436 |
+
dest_tmp = dest + "_dl_" + uuid_tmp + ".tmp"
|
437 |
+
with open(dest_tmp, "wb") as f:
|
438 |
+
shutil.copyfileobj(url_data.raw, f)
|
439 |
+
atomic_move_simple(dest_tmp, dest)
|
440 |
+
if print_func is not None:
|
441 |
+
print_func("END get url %s" % str(url))
|
442 |
+
|
443 |
+
|
444 |
+
def download(url, dest=None, dest_path=None):
|
445 |
+
if dest_path is not None:
|
446 |
+
dest = os.path.join(dest_path, os.path.basename(url))
|
447 |
+
if os.path.isfile(dest):
|
448 |
+
print("already downloaded %s -> %s" % (url, dest))
|
449 |
+
return dest
|
450 |
+
elif dest is not None:
|
451 |
+
if os.path.exists(dest):
|
452 |
+
print("already downloaded %s -> %s" % (url, dest))
|
453 |
+
return dest
|
454 |
+
else:
|
455 |
+
uuid_tmp = "dl2_" + str(uuid.uuid4())[:6]
|
456 |
+
dest = uuid_tmp + os.path.basename(url)
|
457 |
+
|
458 |
+
print("downloading %s to %s" % (url, dest))
|
459 |
+
|
460 |
+
if url.startswith("file://"):
|
461 |
+
from requests_file import FileAdapter
|
462 |
+
s = requests.Session()
|
463 |
+
s.mount('file://', FileAdapter())
|
464 |
+
url_data = s.get(url, stream=True)
|
465 |
+
else:
|
466 |
+
url_data = requests.get(url, stream=True)
|
467 |
+
|
468 |
+
if url_data.status_code != requests.codes.ok:
|
469 |
+
msg = "Cannot get url %s, code: %s, reason: %s" % (
|
470 |
+
str(url), str(url_data.status_code), str(url_data.reason))
|
471 |
+
raise requests.exceptions.RequestException(msg)
|
472 |
+
url_data.raw.decode_content = True
|
473 |
+
dirname = os.path.dirname(dest)
|
474 |
+
if dirname != "" and not os.path.isdir(dirname):
|
475 |
+
makedirs(os.path.dirname(dest), exist_ok=True)
|
476 |
+
uuid_tmp = "dl3_" + str(uuid.uuid4())[:6]
|
477 |
+
dest_tmp = dest + "_" + uuid_tmp + ".tmp"
|
478 |
+
with open(dest_tmp, 'wb') as f:
|
479 |
+
shutil.copyfileobj(url_data.raw, f)
|
480 |
+
try:
|
481 |
+
shutil.move(dest_tmp, dest)
|
482 |
+
except FileExistsError:
|
483 |
+
pass
|
484 |
+
remove_simple(dest_tmp)
|
485 |
+
return dest
|
486 |
+
|
487 |
+
|
488 |
+
def get_url(x, from_str=False, short_name=False):
|
489 |
+
if not from_str:
|
490 |
+
source = x.metadata['source']
|
491 |
+
else:
|
492 |
+
source = x
|
493 |
+
if short_name:
|
494 |
+
source_name = get_short_name(source)
|
495 |
+
else:
|
496 |
+
source_name = source
|
497 |
+
if source.startswith('http://') or source.startswith('https://'):
|
498 |
+
return """<a href="%s" target="_blank" rel="noopener noreferrer">%s</a>""" % (
|
499 |
+
source, source_name)
|
500 |
+
else:
|
501 |
+
return """<a href="file/%s" target="_blank" rel="noopener noreferrer">%s</a>""" % (
|
502 |
+
source, source_name)
|
503 |
+
|
504 |
+
|
505 |
+
def get_short_name(name, maxl=50):
|
506 |
+
if name is None:
|
507 |
+
return ''
|
508 |
+
length = len(name)
|
509 |
+
if length > maxl:
|
510 |
+
allow_length = maxl - 3
|
511 |
+
half_allowed = max(1, int(allow_length / 2))
|
512 |
+
name = name[0:half_allowed] + "..." + name[length - half_allowed:length]
|
513 |
+
return name
|
514 |
+
|
515 |
+
|
516 |
+
def cuda_vis_check(total_gpus):
|
517 |
+
"""Helper function to count GPUs by environment variable
|
518 |
+
Stolen from Jon's h2o4gpu utils
|
519 |
+
"""
|
520 |
+
cudavis = os.getenv("CUDA_VISIBLE_DEVICES")
|
521 |
+
which_gpus = []
|
522 |
+
if cudavis is not None:
|
523 |
+
# prune away white-space, non-numerics,
|
524 |
+
# except commas for simple checking
|
525 |
+
cudavis = "".join(cudavis.split())
|
526 |
+
import re
|
527 |
+
cudavis = re.sub("[^0-9,]", "", cudavis)
|
528 |
+
|
529 |
+
lencudavis = len(cudavis)
|
530 |
+
if lencudavis == 0:
|
531 |
+
total_gpus = 0
|
532 |
+
else:
|
533 |
+
total_gpus = min(
|
534 |
+
total_gpus,
|
535 |
+
os.getenv("CUDA_VISIBLE_DEVICES").count(",") + 1)
|
536 |
+
which_gpus = os.getenv("CUDA_VISIBLE_DEVICES").split(",")
|
537 |
+
which_gpus = [int(x) for x in which_gpus]
|
538 |
+
else:
|
539 |
+
which_gpus = list(range(0, total_gpus))
|
540 |
+
|
541 |
+
return total_gpus, which_gpus
|
542 |
+
|
543 |
+
|
544 |
+
def get_ngpus_vis(raise_if_exception=True):
|
545 |
+
ngpus_vis1 = 0
|
546 |
+
|
547 |
+
shell = False
|
548 |
+
if shell:
|
549 |
+
cmd = "nvidia-smi -L 2> /dev/null"
|
550 |
+
else:
|
551 |
+
cmd = ["nvidia-smi", "-L"]
|
552 |
+
|
553 |
+
try:
|
554 |
+
timeout = 5 * 3
|
555 |
+
o = subprocess.check_output(cmd, shell=shell, timeout=timeout)
|
556 |
+
lines = o.decode("utf-8").splitlines()
|
557 |
+
ngpus_vis1 = 0
|
558 |
+
for line in lines:
|
559 |
+
if 'Failed to initialize NVML' not in line:
|
560 |
+
ngpus_vis1 += 1
|
561 |
+
except (FileNotFoundError, subprocess.CalledProcessError, OSError):
|
562 |
+
# GPU systems might not have nvidia-smi, so can't fail
|
563 |
+
pass
|
564 |
+
except subprocess.TimeoutExpired as e:
|
565 |
+
print('Failed get_ngpus_vis: %s' % str(e))
|
566 |
+
if raise_if_exception:
|
567 |
+
raise
|
568 |
+
|
569 |
+
ngpus_vis1, which_gpus = cuda_vis_check(ngpus_vis1)
|
570 |
+
return ngpus_vis1
|
571 |
+
|
572 |
+
|
573 |
+
def get_mem_gpus(raise_if_exception=True, ngpus=None):
|
574 |
+
totalmem_gpus1 = 0
|
575 |
+
usedmem_gpus1 = 0
|
576 |
+
freemem_gpus1 = 0
|
577 |
+
|
578 |
+
if ngpus == 0:
|
579 |
+
return totalmem_gpus1, usedmem_gpus1, freemem_gpus1
|
580 |
+
|
581 |
+
try:
|
582 |
+
cmd = "nvidia-smi -q 2> /dev/null | grep -A 3 'FB Memory Usage'"
|
583 |
+
o = subprocess.check_output(cmd, shell=True, timeout=15)
|
584 |
+
lines = o.decode("utf-8").splitlines()
|
585 |
+
for line in lines:
|
586 |
+
if 'Total' in line:
|
587 |
+
totalmem_gpus1 += int(line.split()[2]) * 1024 ** 2
|
588 |
+
if 'Used' in line:
|
589 |
+
usedmem_gpus1 += int(line.split()[2]) * 1024 ** 2
|
590 |
+
if 'Free' in line:
|
591 |
+
freemem_gpus1 += int(line.split()[2]) * 1024 ** 2
|
592 |
+
except (FileNotFoundError, subprocess.CalledProcessError, OSError):
|
593 |
+
# GPU systems might not have nvidia-smi, so can't fail
|
594 |
+
pass
|
595 |
+
except subprocess.TimeoutExpired as e:
|
596 |
+
print('Failed get_mem_gpus: %s' % str(e))
|
597 |
+
if raise_if_exception:
|
598 |
+
raise
|
599 |
+
|
600 |
+
return totalmem_gpus1, usedmem_gpus1, freemem_gpus1
|
601 |
+
|
602 |
+
|
603 |
+
class ForkContext(threading.local):
|
604 |
+
"""
|
605 |
+
Set context for forking
|
606 |
+
Ensures state is returned once done
|
607 |
+
"""
|
608 |
+
|
609 |
+
def __init__(self, args=None, kwargs=None, forkdata_capable=True):
|
610 |
+
"""
|
611 |
+
:param args:
|
612 |
+
:param kwargs:
|
613 |
+
:param forkdata_capable: whether fork is forkdata capable and will use copy-on-write forking of args/kwargs
|
614 |
+
"""
|
615 |
+
self.forkdata_capable = forkdata_capable
|
616 |
+
if self.forkdata_capable:
|
617 |
+
self.has_args = args is not None
|
618 |
+
self.has_kwargs = kwargs is not None
|
619 |
+
forkdatacontext.args = args
|
620 |
+
forkdatacontext.kwargs = kwargs
|
621 |
+
else:
|
622 |
+
self.has_args = False
|
623 |
+
self.has_kwargs = False
|
624 |
+
|
625 |
+
def __enter__(self):
|
626 |
+
try:
|
627 |
+
# flush all outputs so doesn't happen during fork -- don't print/log inside ForkContext contexts!
|
628 |
+
sys.stdout.flush()
|
629 |
+
sys.stderr.flush()
|
630 |
+
except BaseException as e:
|
631 |
+
# exit not called if exception, and don't want to leave forkdatacontext filled in that case
|
632 |
+
print("ForkContext failure on enter: %s" % str(e))
|
633 |
+
self.finally_act()
|
634 |
+
raise
|
635 |
+
return self
|
636 |
+
|
637 |
+
def __exit__(self, exc_type, exc_value, exc_traceback):
|
638 |
+
self.finally_act()
|
639 |
+
|
640 |
+
def finally_act(self):
|
641 |
+
"""
|
642 |
+
Done when exception hit or exit is reached in context
|
643 |
+
first reset forkdatacontext as crucial to have reset even if later 2 calls fail
|
644 |
+
:return: None
|
645 |
+
"""
|
646 |
+
if self.forkdata_capable and (self.has_args or self.has_kwargs):
|
647 |
+
forkdatacontext._reset()
|
648 |
+
|
649 |
+
|
650 |
+
class _ForkDataContext(threading.local):
|
651 |
+
def __init__(
|
652 |
+
self,
|
653 |
+
args=None,
|
654 |
+
kwargs=None,
|
655 |
+
):
|
656 |
+
"""
|
657 |
+
Global context for fork to carry data to subprocess instead of relying upon copy/pickle/serialization
|
658 |
+
|
659 |
+
:param args: args
|
660 |
+
:param kwargs: kwargs
|
661 |
+
"""
|
662 |
+
assert isinstance(args, (tuple, type(None)))
|
663 |
+
assert isinstance(kwargs, (dict, type(None)))
|
664 |
+
self.__args = args
|
665 |
+
self.__kwargs = kwargs
|
666 |
+
|
667 |
+
@property
|
668 |
+
def args(self) -> Tuple:
|
669 |
+
"""returns args"""
|
670 |
+
return self.__args
|
671 |
+
|
672 |
+
@args.setter
|
673 |
+
def args(self, args):
|
674 |
+
if self.__args is not None:
|
675 |
+
raise AttributeError(
|
676 |
+
"args cannot be overwritten: %s %s" % (str(self.__args), str(self.__kwargs))
|
677 |
+
)
|
678 |
+
|
679 |
+
self.__args = args
|
680 |
+
|
681 |
+
@property
|
682 |
+
def kwargs(self) -> Dict:
|
683 |
+
"""returns kwargs"""
|
684 |
+
return self.__kwargs
|
685 |
+
|
686 |
+
@kwargs.setter
|
687 |
+
def kwargs(self, kwargs):
|
688 |
+
if self.__kwargs is not None:
|
689 |
+
raise AttributeError(
|
690 |
+
"kwargs cannot be overwritten: %s %s" % (str(self.__args), str(self.__kwargs))
|
691 |
+
)
|
692 |
+
|
693 |
+
self.__kwargs = kwargs
|
694 |
+
|
695 |
+
def _reset(self):
|
696 |
+
"""Reset fork arg-kwarg context to default values"""
|
697 |
+
self.__args = None
|
698 |
+
self.__kwargs = None
|
699 |
+
|
700 |
+
def get_args_kwargs(self, func, args, kwargs) -> Tuple[Callable, Tuple, Dict]:
|
701 |
+
if self.__args:
|
702 |
+
args = self.__args[1:]
|
703 |
+
if not func:
|
704 |
+
assert len(self.__args) > 0, "if have no func, must have in args"
|
705 |
+
func = self.__args[0] # should always be there
|
706 |
+
if self.__kwargs:
|
707 |
+
kwargs = self.__kwargs
|
708 |
+
try:
|
709 |
+
return func, args, kwargs
|
710 |
+
finally:
|
711 |
+
forkdatacontext._reset()
|
712 |
+
|
713 |
+
@staticmethod
|
714 |
+
def get_args_kwargs_for_traced_func(func, args, kwargs):
|
715 |
+
"""
|
716 |
+
Return args/kwargs out of forkdatacontext when using copy-on-write way of passing args/kwargs
|
717 |
+
:param func: actual function ran by _traced_func, which itself is directly what mppool treats as function
|
718 |
+
:param args:
|
719 |
+
:param kwargs:
|
720 |
+
:return: func, args, kwargs from forkdatacontext if used, else originals
|
721 |
+
"""
|
722 |
+
# first 3 lines are debug
|
723 |
+
func_was_None = func is None
|
724 |
+
args_was_None_or_empty = args is None or len(args) == 0
|
725 |
+
kwargs_was_None_or_empty = kwargs is None or len(kwargs) == 0
|
726 |
+
|
727 |
+
forkdatacontext_args_was_None = forkdatacontext.args is None
|
728 |
+
forkdatacontext_kwargs_was_None = forkdatacontext.kwargs is None
|
729 |
+
func, args, kwargs = forkdatacontext.get_args_kwargs(func, args, kwargs)
|
730 |
+
using_forkdatacontext = func_was_None and func is not None # pulled func out of forkdatacontext.__args[0]
|
731 |
+
assert forkdatacontext.args is None, "forkdatacontext.args should be None after get_args_kwargs"
|
732 |
+
assert forkdatacontext.kwargs is None, "forkdatacontext.kwargs should be None after get_args_kwargs"
|
733 |
+
|
734 |
+
proc_type = kwargs.get('proc_type', 'SUBPROCESS')
|
735 |
+
if using_forkdatacontext:
|
736 |
+
assert proc_type == "SUBPROCESS" or proc_type == "SUBPROCESS"
|
737 |
+
if proc_type == "NORMAL":
|
738 |
+
assert forkdatacontext_args_was_None, "if no fork, expect forkdatacontext.args None entering _traced_func"
|
739 |
+
assert forkdatacontext_kwargs_was_None, "if no fork, expect forkdatacontext.kwargs None entering _traced_func"
|
740 |
+
assert func is not None, "function should not be None, indicates original args[0] was None or args was None"
|
741 |
+
|
742 |
+
return func, args, kwargs
|
743 |
+
|
744 |
+
|
745 |
+
forkdatacontext = _ForkDataContext()
|
746 |
+
|
747 |
+
|
748 |
+
def _traced_func(func, *args, **kwargs):
|
749 |
+
func, args, kwargs = forkdatacontext.get_args_kwargs_for_traced_func(func, args, kwargs)
|
750 |
+
return func(*args, **kwargs)
|
751 |
+
|
752 |
+
|
753 |
+
def call_subprocess_onetask(func, args=None, kwargs=None):
|
754 |
+
if isinstance(args, list):
|
755 |
+
args = tuple(args)
|
756 |
+
if args is None:
|
757 |
+
args = ()
|
758 |
+
if kwargs is None:
|
759 |
+
kwargs = {}
|
760 |
+
args = list(args)
|
761 |
+
args = [func] + args
|
762 |
+
args = tuple(args)
|
763 |
+
with ForkContext(args=args, kwargs=kwargs):
|
764 |
+
args = (None,)
|
765 |
+
kwargs = {}
|
766 |
+
with ProcessPoolExecutor(max_workers=1) as executor:
|
767 |
+
future = executor.submit(_traced_func, *args, **kwargs)
|
768 |
+
return future.result()
|