Spaces:
Running
pseudotensor committed on
Commit • 1ec3d3a
1 Parent(s): 30e5d19
Update with h2oGPT hash 221daabcabfa7f54b732394c15934a347da01079
Browse files
- client_test.py +166 -9
- create_data.py +2 -2
- enums.py +84 -0
- finetune.py +32 -26
- generate.py +0 -0
- gpt4all_llm.py +66 -24
- gpt_langchain.py +1071 -232
- gradio_runner.py +0 -0
- gradio_themes.py +46 -6
- gradio_utils/__pycache__/css.cpython-310.pyc +0 -0
- gradio_utils/__pycache__/grclient.cpython-310.pyc +0 -0
- gradio_utils/__pycache__/prompt_form.cpython-310.pyc +0 -0
- gradio_utils/css.py +53 -0
- gradio_utils/grclient.py +82 -0
- gradio_utils/prompt_form.py +118 -0
- h2oai_pipeline.py +77 -8
- loaders.py +5 -2
- prompter.py +285 -107
- requirements.txt +79 -26
- stopping.py +10 -4
- utils.py +115 -30
- utils_langchain.py +64 -0
client_test.py
CHANGED
@@ -48,6 +48,8 @@ import markdown # pip install markdown
|
|
48 |
import pytest
|
49 |
from bs4 import BeautifulSoup # pip install beautifulsoup4
|
50 |
|
|
|
|
|
51 |
debug = False
|
52 |
|
53 |
os.environ['HF_HUB_DISABLE_TELEMETRY'] = '1'
|
@@ -62,7 +64,10 @@ def get_client(serialize=True):
|
|
62 |
return client
|
63 |
|
64 |
|
65 |
-
def get_args(prompt, prompt_type, chat=False, stream_output=False,
|
66 |
from collections import OrderedDict
|
67 |
kwargs = OrderedDict(instruction=prompt if chat else '', # only for chat=True
|
68 |
iinput='', # only for chat=True
|
@@ -71,6 +76,7 @@ def get_args(prompt, prompt_type, chat=False, stream_output=False, max_new_token
|
|
71 |
# but leave stream_output=False for simple input/output mode
|
72 |
stream_output=stream_output,
|
73 |
prompt_type=prompt_type,
|
|
|
74 |
temperature=0.1,
|
75 |
top_p=0.75,
|
76 |
top_k=40,
|
@@ -86,9 +92,13 @@ def get_args(prompt, prompt_type, chat=False, stream_output=False, max_new_token
|
|
86 |
instruction_nochat=prompt if not chat else '',
|
87 |
iinput_nochat='', # only for chat=False
|
88 |
langchain_mode=langchain_mode,
|
89 |
-
top_k_docs=
|
90 |
-
|
|
|
|
|
91 |
)
|
|
|
|
|
92 |
if chat:
|
93 |
# add chatbot output on end. Assumes serialize=False
|
94 |
kwargs.update(dict(chatbot=[]))
|
@@ -97,8 +107,8 @@ def get_args(prompt, prompt_type, chat=False, stream_output=False, max_new_token
|
|
97 |
|
98 |
|
99 |
@pytest.mark.skip(reason="For manual use against some server, no server launched")
|
100 |
-
def test_client_basic():
|
101 |
-
return run_client_nochat(prompt='Who are you?', prompt_type=
|
102 |
|
103 |
|
104 |
def run_client_nochat(prompt, prompt_type, max_new_tokens):
|
@@ -112,15 +122,110 @@ def run_client_nochat(prompt, prompt_type, max_new_tokens):
|
|
112 |
)
|
113 |
print("Raw client result: %s" % res, flush=True)
|
114 |
res_dict = dict(prompt=kwargs['instruction_nochat'], iinput=kwargs['iinput_nochat'],
|
115 |
response=md_to_text(ast.literal_eval(res)['response']),
|
116 |
sources=ast.literal_eval(res)['sources'])
|
117 |
print(res_dict)
|
118 |
-
return res_dict
|
119 |
|
120 |
|
121 |
@pytest.mark.skip(reason="For manual use against some server, no server launched")
|
122 |
-
def
|
123 |
-
return run_client_chat(prompt=
|
|
|
124 |
langchain_mode='Disabled')
|
125 |
|
126 |
|
@@ -133,6 +238,7 @@ def run_client_chat(prompt, prompt_type, stream_output, max_new_tokens, langchai
|
|
133 |
|
134 |
|
135 |
def run_client(client, prompt, args, kwargs, do_md_to_text=True, verbose=False):
|
|
|
136 |
res = client.predict(*tuple(args), api_name='/instruction')
|
137 |
args[-1] += [res[-1]]
|
138 |
|
@@ -166,6 +272,46 @@ def run_client(client, prompt, args, kwargs, do_md_to_text=True, verbose=False):
|
|
166 |
return res_dict, client
|
167 |
|
168 |
|
169 |
def md_to_text(md, do_md_to_text=True):
|
170 |
if not do_md_to_text:
|
171 |
return md
|
@@ -175,5 +321,16 @@ def md_to_text(md, do_md_to_text=True):
|
|
175 |
return soup.get_text()
|
176 |
|
177 |
|
178 |
if __name__ == '__main__':
|
179 |
-
|
|
|
48 |
import pytest
|
49 |
from bs4 import BeautifulSoup # pip install beautifulsoup4
|
50 |
|
51 |
+
from enums import DocumentChoices
|
52 |
+
|
53 |
debug = False
|
54 |
|
55 |
os.environ['HF_HUB_DISABLE_TELEMETRY'] = '1'
|
|
|
64 |
return client
|
65 |
|
66 |
|
67 |
+
def get_args(prompt, prompt_type, chat=False, stream_output=False,
|
68 |
+
max_new_tokens=50,
|
69 |
+
top_k_docs=3,
|
70 |
+
langchain_mode='Disabled'):
|
71 |
from collections import OrderedDict
|
72 |
kwargs = OrderedDict(instruction=prompt if chat else '', # only for chat=True
|
73 |
iinput='', # only for chat=True
|
|
|
76 |
# but leave stream_output=False for simple input/output mode
|
77 |
stream_output=stream_output,
|
78 |
prompt_type=prompt_type,
|
79 |
+
prompt_dict='',
|
80 |
temperature=0.1,
|
81 |
top_p=0.75,
|
82 |
top_k=40,
|
|
|
92 |
instruction_nochat=prompt if not chat else '',
|
93 |
iinput_nochat='', # only for chat=False
|
94 |
langchain_mode=langchain_mode,
|
95 |
+
top_k_docs=top_k_docs,
|
96 |
+
chunk=True,
|
97 |
+
chunk_size=512,
|
98 |
+
document_choice=[DocumentChoices.All_Relevant.name],
|
99 |
)
|
100 |
+
from generate import eval_func_param_names
|
101 |
+
assert len(set(eval_func_param_names).difference(set(list(kwargs.keys())))) == 0
|
102 |
if chat:
|
103 |
# add chatbot output on end. Assumes serialize=False
|
104 |
kwargs.update(dict(chatbot=[]))
|
|
|
107 |
|
108 |
|
109 |
@pytest.mark.skip(reason="For manual use against some server, no server launched")
|
110 |
+
def test_client_basic(prompt_type='human_bot'):
|
111 |
+
return run_client_nochat(prompt='Who are you?', prompt_type=prompt_type, max_new_tokens=50)
|
112 |
|
113 |
|
114 |
def run_client_nochat(prompt, prompt_type, max_new_tokens):
|
|
|
122 |
)
|
123 |
print("Raw client result: %s" % res, flush=True)
|
124 |
res_dict = dict(prompt=kwargs['instruction_nochat'], iinput=kwargs['iinput_nochat'],
|
125 |
+
response=md_to_text(res))
|
126 |
+
print(res_dict)
|
127 |
+
return res_dict, client
|
128 |
+
|
129 |
+
|
130 |
+
@pytest.mark.skip(reason="For manual use against some server, no server launched")
|
131 |
+
def test_client_basic_api(prompt_type='human_bot'):
|
132 |
+
return run_client_nochat_api(prompt='Who are you?', prompt_type=prompt_type, max_new_tokens=50)
|
133 |
+
|
134 |
+
|
135 |
+
def run_client_nochat_api(prompt, prompt_type, max_new_tokens):
|
136 |
+
kwargs, args = get_args(prompt, prompt_type, chat=False, max_new_tokens=max_new_tokens)
|
137 |
+
|
138 |
+
api_name = '/submit_nochat_api' # NOTE: like submit_nochat but stable API for string dict passing
|
139 |
+
client = get_client(serialize=True)
|
140 |
+
res = client.predict(
|
141 |
+
str(dict(kwargs)),
|
142 |
+
api_name=api_name,
|
143 |
+
)
|
144 |
+
print("Raw client result: %s" % res, flush=True)
|
145 |
+
res_dict = dict(prompt=kwargs['instruction_nochat'], iinput=kwargs['iinput_nochat'],
|
146 |
+
response=md_to_text(ast.literal_eval(res)['response']),
|
147 |
+
sources=ast.literal_eval(res)['sources'])
|
148 |
+
print(res_dict)
|
149 |
+
return res_dict, client
|
150 |
+
|
151 |
+
|
152 |
+
@pytest.mark.skip(reason="For manual use against some server, no server launched")
|
153 |
+
def test_client_basic_api_lean(prompt_type='human_bot'):
|
154 |
+
return run_client_nochat_api_lean(prompt='Who are you?', prompt_type=prompt_type, max_new_tokens=50)
|
155 |
+
|
156 |
+
|
157 |
+
def run_client_nochat_api_lean(prompt, prompt_type, max_new_tokens):
|
158 |
+
kwargs = dict(instruction_nochat=prompt)
|
159 |
+
|
160 |
+
api_name = '/submit_nochat_api' # NOTE: like submit_nochat but stable API for string dict passing
|
161 |
+
client = get_client(serialize=True)
|
162 |
+
res = client.predict(
|
163 |
+
str(dict(kwargs)),
|
164 |
+
api_name=api_name,
|
165 |
+
)
|
166 |
+
print("Raw client result: %s" % res, flush=True)
|
167 |
+
res_dict = dict(prompt=kwargs['instruction_nochat'],
|
168 |
+
response=md_to_text(ast.literal_eval(res)['response']),
|
169 |
+
sources=ast.literal_eval(res)['sources'])
|
170 |
+
print(res_dict)
|
171 |
+
return res_dict, client
|
172 |
+
|
173 |
+
|
174 |
+
@pytest.mark.skip(reason="For manual use against some server, no server launched")
|
175 |
+
def test_client_basic_api_lean_morestuff(prompt_type='human_bot'):
|
176 |
+
return run_client_nochat_api_lean_morestuff(prompt='Who are you?', prompt_type=prompt_type, max_new_tokens=50)
|
177 |
+
|
178 |
+
|
179 |
+
def run_client_nochat_api_lean_morestuff(prompt, prompt_type='human_bot', max_new_tokens=512):
|
180 |
+
kwargs = dict(
|
181 |
+
instruction='',
|
182 |
+
iinput='',
|
183 |
+
context='',
|
184 |
+
stream_output=False,
|
185 |
+
prompt_type=prompt_type,
|
186 |
+
temperature=0.1,
|
187 |
+
top_p=0.75,
|
188 |
+
top_k=40,
|
189 |
+
num_beams=1,
|
190 |
+
max_new_tokens=256,
|
191 |
+
min_new_tokens=0,
|
192 |
+
early_stopping=False,
|
193 |
+
max_time=20,
|
194 |
+
repetition_penalty=1.0,
|
195 |
+
num_return_sequences=1,
|
196 |
+
do_sample=True,
|
197 |
+
chat=False,
|
198 |
+
instruction_nochat=prompt,
|
199 |
+
iinput_nochat='',
|
200 |
+
langchain_mode='Disabled',
|
201 |
+
top_k_docs=4,
|
202 |
+
document_choice=['All'],
|
203 |
+
)
|
204 |
+
|
205 |
+
api_name = '/submit_nochat_api' # NOTE: like submit_nochat but stable API for string dict passing
|
206 |
+
client = get_client(serialize=True)
|
207 |
+
res = client.predict(
|
208 |
+
str(dict(kwargs)),
|
209 |
+
api_name=api_name,
|
210 |
+
)
|
211 |
+
print("Raw client result: %s" % res, flush=True)
|
212 |
+
res_dict = dict(prompt=kwargs['instruction_nochat'],
|
213 |
response=md_to_text(ast.literal_eval(res)['response']),
|
214 |
sources=ast.literal_eval(res)['sources'])
|
215 |
print(res_dict)
|
216 |
+
return res_dict, client
|
217 |
+
|
218 |
+
|
219 |
+
@pytest.mark.skip(reason="For manual use against some server, no server launched")
|
220 |
+
def test_client_chat(prompt_type='human_bot'):
|
221 |
+
return run_client_chat(prompt='Who are you?', prompt_type=prompt_type, stream_output=False, max_new_tokens=50,
|
222 |
+
langchain_mode='Disabled')
|
223 |
|
224 |
|
225 |
@pytest.mark.skip(reason="For manual use against some server, no server launched")
|
226 |
+
def test_client_chat_stream(prompt_type='human_bot'):
|
227 |
+
return run_client_chat(prompt="Tell a very long kid's story about birds.", prompt_type=prompt_type,
|
228 |
+
stream_output=True, max_new_tokens=512,
|
229 |
langchain_mode='Disabled')
|
230 |
|
231 |
|
|
|
238 |
|
239 |
|
240 |
def run_client(client, prompt, args, kwargs, do_md_to_text=True, verbose=False):
|
241 |
+
assert kwargs['chat'], "Chat mode only"
|
242 |
res = client.predict(*tuple(args), api_name='/instruction')
|
243 |
args[-1] += [res[-1]]
|
244 |
|
|
|
272 |
return res_dict, client
|
273 |
|
274 |
|
275 |
+
@pytest.mark.skip(reason="For manual use against some server, no server launched")
|
276 |
+
def test_client_nochat_stream(prompt_type='human_bot'):
|
277 |
+
return run_client_nochat_gen(prompt="Tell a very long kid's story about birds.", prompt_type=prompt_type,
|
278 |
+
stream_output=True, max_new_tokens=512,
|
279 |
+
langchain_mode='Disabled')
|
280 |
+
|
281 |
+
|
282 |
+
def run_client_nochat_gen(prompt, prompt_type, stream_output, max_new_tokens, langchain_mode):
|
283 |
+
client = get_client(serialize=False)
|
284 |
+
|
285 |
+
kwargs, args = get_args(prompt, prompt_type, chat=False, stream_output=stream_output,
|
286 |
+
max_new_tokens=max_new_tokens, langchain_mode=langchain_mode)
|
287 |
+
return run_client_gen(client, prompt, args, kwargs)
|
288 |
+
|
289 |
+
|
290 |
+
def run_client_gen(client, prompt, args, kwargs, do_md_to_text=True, verbose=False):
|
291 |
+
res_dict = kwargs
|
292 |
+
res_dict['prompt'] = prompt
|
293 |
+
if not kwargs['stream_output']:
|
294 |
+
res = client.predict(str(dict(kwargs)), api_name='/submit_nochat_api')
|
295 |
+
res_dict['response'] = res[0]
|
296 |
+
print(md_to_text(res_dict['response'], do_md_to_text=do_md_to_text))
|
297 |
+
return res_dict, client
|
298 |
+
else:
|
299 |
+
job = client.submit(str(dict(kwargs)), api_name='/submit_nochat_api')
|
300 |
+
while not job.done():
|
301 |
+
outputs_list = job.communicator.job.outputs
|
302 |
+
if outputs_list:
|
303 |
+
res = job.communicator.job.outputs[-1]
|
304 |
+
res_dict = ast.literal_eval(res)
|
305 |
+
print('Stream: %s' % res_dict['response'])
|
306 |
+
time.sleep(0.1)
|
307 |
+
res_list = job.outputs()
|
308 |
+
assert len(res_list) > 0, "No response, check server"
|
309 |
+
res = res_list[-1]
|
310 |
+
res_dict = ast.literal_eval(res)
|
311 |
+
print('Final: %s' % res_dict['response'])
|
312 |
+
return res_dict, client
|
313 |
+
|
314 |
+
|
315 |
def md_to_text(md, do_md_to_text=True):
|
316 |
if not do_md_to_text:
|
317 |
return md
|
|
|
321 |
return soup.get_text()
|
322 |
|
323 |
|
324 |
+
def run_client_many(prompt_type='human_bot'):
|
325 |
+
ret1, _ = test_client_chat(prompt_type=prompt_type)
|
326 |
+
ret2, _ = test_client_chat_stream(prompt_type=prompt_type)
|
327 |
+
ret3, _ = test_client_nochat_stream(prompt_type=prompt_type)
|
328 |
+
ret4, _ = test_client_basic(prompt_type=prompt_type)
|
329 |
+
ret5, _ = test_client_basic_api(prompt_type=prompt_type)
|
330 |
+
ret6, _ = test_client_basic_api_lean(prompt_type=prompt_type)
|
331 |
+
ret7, _ = test_client_basic_api_lean_morestuff(prompt_type=prompt_type)
|
332 |
+
return ret1, ret2, ret3, ret4, ret5, ret6, ret7
|
333 |
+
|
334 |
+
|
335 |
if __name__ == '__main__':
|
336 |
+
run_client_many()
|
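The new /submit_nochat_api endpoint accepts a stringified dict of generation parameters and returns a stringified dict with "response" and "sources". A minimal sketch of calling it from outside the test file, assuming a local server on the default port (the URL and the direct use of gradio_client.Client are assumptions; the tests above go through the repo's get_client helper):

import ast
from gradio_client import Client  # pip install gradio_client

client = Client("http://localhost:7860")  # hypothetical local h2oGPT server
kwargs = dict(instruction_nochat="Who are you?")
res = client.predict(str(dict(kwargs)), api_name="/submit_nochat_api")
res_dict = ast.literal_eval(res)  # endpoint returns a stringified dict
print(res_dict["response"])
print(res_dict["sources"])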
create_data.py
CHANGED
@@ -567,7 +567,7 @@ def test_show_prompts():
|
|
567 |
from prompter import generate_prompt
|
568 |
for data_points in file_points:
|
569 |
for data_point in data_points:
|
570 |
-
print(generate_prompt(data_point, 'plain', False, False)[0])
|
571 |
|
572 |
|
573 |
def test_get_open_datasets():
|
@@ -1571,7 +1571,7 @@ def test_check_stats_data():
|
|
1571 |
|
1572 |
llama_type = False
|
1573 |
tokenizer_base_model = base_model = 'h2oai/h2ogpt-oasst1-512-20b'
|
1574 |
-
model_loader, tokenizer_loader = get_loaders(
|
1575 |
local_files_only = False
|
1576 |
resume_download = True
|
1577 |
use_auth_token = False
|
|
|
567 |
from prompter import generate_prompt
|
568 |
for data_points in file_points:
|
569 |
for data_point in data_points:
|
570 |
+
print(generate_prompt(data_point, 'plain', '', False, False, False)[0])
|
571 |
|
572 |
|
573 |
def test_get_open_datasets():
|
|
|
1571 |
|
1572 |
llama_type = False
|
1573 |
tokenizer_base_model = base_model = 'h2oai/h2ogpt-oasst1-512-20b'
|
1574 |
+
model_loader, tokenizer_loader = get_loaders(model_name=base_model, reward_type=False, llama_type=llama_type)
|
1575 |
local_files_only = False
|
1576 |
resume_download = True
|
1577 |
use_auth_token = False
|
enums.py
ADDED
@@ -0,0 +1,84 @@
|
1 |
+
from enum import Enum
|
2 |
+
|
3 |
+
|
4 |
+
class PromptType(Enum):
|
5 |
+
custom = -1
|
6 |
+
plain = 0
|
7 |
+
instruct = 1
|
8 |
+
quality = 2
|
9 |
+
human_bot = 3
|
10 |
+
dai_faq = 4
|
11 |
+
summarize = 5
|
12 |
+
simple_instruct = 6
|
13 |
+
instruct_vicuna = 7
|
14 |
+
instruct_with_end = 8
|
15 |
+
human_bot_orig = 9
|
16 |
+
prompt_answer = 10
|
17 |
+
open_assistant = 11
|
18 |
+
wizard_lm = 12
|
19 |
+
wizard_mega = 13
|
20 |
+
instruct_vicuna2 = 14
|
21 |
+
instruct_vicuna3 = 15
|
22 |
+
wizard2 = 16
|
23 |
+
wizard3 = 17
|
24 |
+
instruct_simple = 18
|
25 |
+
wizard_vicuna = 19
|
26 |
+
openai = 20
|
27 |
+
openai_chat = 21
|
28 |
+
gptj = 22
|
29 |
+
prompt_answer_openllama = 23
|
30 |
+
vicuna11 = 24
|
31 |
+
|
32 |
+
|
33 |
+
class DocumentChoices(Enum):
|
34 |
+
All_Relevant = 0
|
35 |
+
All_Relevant_Only_Sources = 1
|
36 |
+
Only_All_Sources = 2
|
37 |
+
Just_LLM = 3
|
38 |
+
|
39 |
+
|
40 |
+
class LangChainMode(Enum):
|
41 |
+
"""LangChain mode"""
|
42 |
+
|
43 |
+
DISABLED = "Disabled"
|
44 |
+
CHAT_LLM = "ChatLLM"
|
45 |
+
LLM = "LLM"
|
46 |
+
ALL = "All"
|
47 |
+
WIKI = "wiki"
|
48 |
+
WIKI_FULL = "wiki_full"
|
49 |
+
USER_DATA = "UserData"
|
50 |
+
MY_DATA = "MyData"
|
51 |
+
GITHUB_H2OGPT = "github h2oGPT"
|
52 |
+
H2O_DAI_DOCS = "DriverlessAI docs"
|
53 |
+
|
54 |
+
|
55 |
+
no_server_str = no_lora_str = no_model_str = '[None/Remove]'
|
56 |
+
|
57 |
+
|
58 |
+
# from site-packages/langchain/llms/openai.py, but needed since ChatOpenAI doesn't have this information
|
59 |
+
model_token_mapping = {
|
60 |
+
"gpt-4": 8192,
|
61 |
+
"gpt-4-0314": 8192,
|
62 |
+
"gpt-4-32k": 32768,
|
63 |
+
"gpt-4-32k-0314": 32768,
|
64 |
+
"gpt-3.5-turbo": 4096,
|
65 |
+
"gpt-3.5-turbo-16k": 16*1024,
|
66 |
+
"gpt-3.5-turbo-0301": 4096,
|
67 |
+
"text-ada-001": 2049,
|
68 |
+
"ada": 2049,
|
69 |
+
"text-babbage-001": 2040,
|
70 |
+
"babbage": 2049,
|
71 |
+
"text-curie-001": 2049,
|
72 |
+
"curie": 2049,
|
73 |
+
"davinci": 2049,
|
74 |
+
"text-davinci-003": 4097,
|
75 |
+
"text-davinci-002": 4097,
|
76 |
+
"code-davinci-002": 8001,
|
77 |
+
"code-davinci-001": 8001,
|
78 |
+
"code-cushman-002": 2048,
|
79 |
+
"code-cushman-001": 2048,
|
80 |
+
}
|
81 |
+
|
82 |
+
|
83 |
+
source_prefix = "Sources [Score | Link]:"
|
84 |
+
source_postfix = "End Sources<p>"
|
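The new enums.py centralizes values that other files in this commit now import. A short sketch of how these names are consumed elsewhere in the diff (the max_ctx lookup at the end is an illustration, not code from this commit):

from enums import DocumentChoices, LangChainMode, PromptType, model_token_mapping

document_choice = [DocumentChoices.All_Relevant.name]  # as used in client_test.get_args
langchain_mode = LangChainMode.DISABLED.value          # "Disabled"
prompt_type = PromptType.human_bot.name                # "human_bot"

# model_token_mapping provides context lengths for OpenAI models
# (per the comment in enums.py, since ChatOpenAI does not expose them).
max_ctx = model_token_mapping.get("gpt-3.5-turbo", 4096)
print(document_choice, langchain_mode, prompt_type, max_ctx)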
finetune.py
CHANGED
@@ -5,8 +5,11 @@ from typing import List, Union
|
|
5 |
import fire
|
6 |
import numpy as np
|
7 |
|
|
|
|
|
|
|
8 |
from loaders import get_loaders, get_tokenizer
|
9 |
-
from prompter import generate_prompt, prompt_types
|
10 |
from utils import get_githash, copy_code
|
11 |
import torch
|
12 |
|
@@ -104,7 +107,6 @@ def train(
|
|
104 |
save_total_limit: int = 3,
|
105 |
add_eos_token: bool = False,
|
106 |
):
|
107 |
-
|
108 |
if llama_flash_attn:
|
109 |
# Need to call this before importing transformers.
|
110 |
from llama_flash_attn_monkey_patch import replace_llama_attn_with_flash_attn
|
@@ -129,10 +131,12 @@ def train(
|
|
129 |
if not output_dir:
|
130 |
output_dir = f"{base_model.split('/')[-1]}.{data_path.replace('/', '')}.{num_epochs}_epochs.{get_githash() or 'nogit'}.{run_id}"
|
131 |
if os.path.exists(output_dir) and not resume_from_checkpoint:
|
132 |
-
raise FileExistsError(
|
|
|
133 |
else:
|
134 |
if os.path.exists(output_dir) and not resume_from_checkpoint:
|
135 |
-
raise FileExistsError(
|
|
|
136 |
device_map = "auto"
|
137 |
|
138 |
if save_code:
|
@@ -181,7 +185,7 @@ def train(
|
|
181 |
log("num_gpus: %d" % gpus)
|
182 |
log("max mem: %s" % max_memory)
|
183 |
|
184 |
-
model_loader, tokenizer_loader = get_loaders(
|
185 |
|
186 |
model = model_loader.from_pretrained(
|
187 |
base_model,
|
@@ -398,7 +402,8 @@ def train(
|
|
398 |
if train_data_mix_in:
|
399 |
train_data = concatenate_datasets([train_data, train_data_mix_in])
|
400 |
log("Tokenizing %s training rows" % train_data.num_rows)
|
401 |
-
train_data = train_data.shuffle().map(generate_and_tokenize_prompt_fun,
|
|
|
402 |
if drop_truncations:
|
403 |
log("avoid keeping truncated cases to avoid contaminating model with truncation cases. Original size: %s" % train_data.num_rows)
|
404 |
prune_long_sequences_func = partial(prune_long_sequences, cutoff_len=cutoff_len)
|
@@ -413,7 +418,8 @@ def train(
|
|
413 |
|
414 |
if valid_data:
|
415 |
log("Tokenizing %s validation rows" % valid_data.num_rows)
|
416 |
-
valid_data = valid_data.shuffle().map(generate_and_tokenize_prompt_fun,
|
|
|
417 |
val_set_size = len(valid_data)
|
418 |
else:
|
419 |
val_set_size = 0
|
@@ -468,7 +474,7 @@ def train(
|
|
468 |
elif save_steps > eval_steps:
|
469 |
# save steps must be round multiple of eval_steps
|
470 |
save_steps0 = save_steps
|
471 |
-
save_steps = max(1, (save_steps//eval_steps)) * eval_steps
|
472 |
if save_steps0 != save_steps:
|
473 |
log("Auto converted save_steps from %s to %s" % (save_steps0, save_steps))
|
474 |
|
@@ -478,21 +484,21 @@ def train(
|
|
478 |
label_ids = eval_preds.label_ids
|
479 |
predictions = eval_preds.predictions
|
480 |
|
481 |
-
#inputs = np.where(inputs != -100, inputs, tokenizer.pad_token_id)
|
482 |
-
#decoded_inputs = tokenizer.batch_decode(inputs, skip_special_tokens=True)
|
483 |
-
#decoded_inputs = [pred.strip() for pred in decoded_inputs]
|
484 |
|
485 |
label_ids = np.where(label_ids != -100, label_ids, tokenizer.pad_token_id)
|
486 |
# tokenizer behavior like generate time
|
487 |
decoded_labels = tokenizer.batch_decode(label_ids, skip_special_tokens=True,
|
488 |
-
|
489 |
decoded_labels = [pred.strip() for pred in decoded_labels]
|
490 |
|
491 |
predictions = np.argmax(predictions, -1)
|
492 |
predictions = np.where(predictions != -100, predictions, tokenizer.pad_token_id)
|
493 |
# tokenizer behavior like generate time
|
494 |
decoded_predictions = tokenizer.batch_decode(predictions, skip_special_tokens=True,
|
495 |
-
|
496 |
decoded_predictions = [pred.strip() for pred in decoded_predictions]
|
497 |
|
498 |
result = {}
|
@@ -541,8 +547,8 @@ def train(
|
|
541 |
load_best_model_at_end=True if val_set_size > 0 else False,
|
542 |
ddp_find_unused_parameters=False if ddp else None,
|
543 |
group_by_length=group_by_length,
|
544 |
-
#fsdp="shard_grad_op auto_wrap" if gpus > 1 and not ddp else None,
|
545 |
-
#fsdp_min_num_params=20000 if gpus > 1 and not ddp else None,
|
546 |
report_to='tensorboard' if not neptune_run else 'neptune',
|
547 |
),
|
548 |
data_collator=transformers.DataCollatorForSeq2Seq(
|
@@ -553,13 +559,6 @@ def train(
|
|
553 |
)
|
554 |
model.config.use_cache = False
|
555 |
|
556 |
-
old_state_dict = model.state_dict
|
557 |
-
from peft import get_peft_model_state_dict
|
558 |
-
|
559 |
-
model.state_dict = (
|
560 |
-
lambda self, *_, **__: get_peft_model_state_dict(self, old_state_dict())
|
561 |
-
).__get__(model, type(model))
|
562 |
-
|
563 |
if torch.__version__ >= "2" and sys.platform != "win32":
|
564 |
model = torch.compile(model)
|
565 |
# WIP (not generally replacing layers until pytorch 2.1)
|
@@ -616,10 +615,12 @@ def generate_and_tokenize_prompt(data_point, prompt_type=None, train_on_inputs=F
|
|
616 |
assert prompt_type is not None
|
617 |
assert cutoff_len is not None
|
618 |
assert tokenizer is not None
|
619 |
-
|
|
|
|
|
620 |
tokenized_full_prompt = tokenize(full_prompt, tokenizer, cutoff_len, add_eos_token=add_eos_token)
|
621 |
if not train_on_inputs:
|
622 |
-
user_prompt, _, _, _ = generate_prompt({**data_point, "output": ""}, prompt_type, False, False)
|
623 |
tokenized_user_prompt = tokenize(user_prompt, tokenizer, cutoff_len, add_eos_token=add_eos_token)
|
624 |
user_prompt_len = len(tokenized_user_prompt["input_ids"])
|
625 |
if add_eos_token:
|
@@ -638,7 +639,7 @@ def test_debug():
|
|
638 |
fire.Fire(train)
|
639 |
|
640 |
|
641 |
-
|
642 |
CONFIG = "NCCL_P2P_LEVEL=LOC WORLD_SIZE=5 torchrun --nnodes=5 --master_addr=10.10.10.2 --master_port=1111 --nproc_per_node=1"
|
643 |
CMD = "finetune.py --data_path=config.json --num_epochs=1 --base_model=decapoda-research/llama-13b-hf"
|
644 |
log(f"""
|
@@ -665,6 +666,11 @@ NCCL_P2P_LEVEL=LOC WORLD_SIZE=7 CUDA_VISIBLE_DEVICES="0,1" torchrun --node_rank
|
|
665 |
|
666 |
if os.environ.get("LOCAL_RANK") is None:
|
667 |
# then not using torchrun, so can't do distributed, ensure CVD set
|
668 |
-
assert os.environ.get(
|
|
|
669 |
|
670 |
fire.Fire(train)
|
5 |
import fire
|
6 |
import numpy as np
|
7 |
|
8 |
+
if os.path.dirname(os.path.abspath(__file__)) not in sys.path:
|
9 |
+
sys.path.append(os.path.dirname(os.path.abspath(__file__)))
|
10 |
+
|
11 |
from loaders import get_loaders, get_tokenizer
|
12 |
+
from prompter import generate_prompt, prompt_types, PromptType
|
13 |
from utils import get_githash, copy_code
|
14 |
import torch
|
15 |
|
|
|
107 |
save_total_limit: int = 3,
|
108 |
add_eos_token: bool = False,
|
109 |
):
|
|
|
110 |
if llama_flash_attn:
|
111 |
# Need to call this before importing transformers.
|
112 |
from llama_flash_attn_monkey_patch import replace_llama_attn_with_flash_attn
|
|
|
131 |
if not output_dir:
|
132 |
output_dir = f"{base_model.split('/')[-1]}.{data_path.replace('/', '')}.{num_epochs}_epochs.{get_githash() or 'nogit'}.{run_id}"
|
133 |
if os.path.exists(output_dir) and not resume_from_checkpoint:
|
134 |
+
raise FileExistsError(
|
135 |
+
f"output_dir {output_dir} based on run_id {run_id} already exists. Please pick a different run_id.")
|
136 |
else:
|
137 |
if os.path.exists(output_dir) and not resume_from_checkpoint:
|
138 |
+
raise FileExistsError(
|
139 |
+
f"output_dir {output_dir} already exists. Please pick a different output_dir, or specify a run_id instead.")
|
140 |
device_map = "auto"
|
141 |
|
142 |
if save_code:
|
|
|
185 |
log("num_gpus: %d" % gpus)
|
186 |
log("max mem: %s" % max_memory)
|
187 |
|
188 |
+
model_loader, tokenizer_loader = get_loaders(model_name=base_model, reward_type=False, llama_type=llama_type)
|
189 |
|
190 |
model = model_loader.from_pretrained(
|
191 |
base_model,
|
|
|
402 |
if train_data_mix_in:
|
403 |
train_data = concatenate_datasets([train_data, train_data_mix_in])
|
404 |
log("Tokenizing %s training rows" % train_data.num_rows)
|
405 |
+
train_data = train_data.shuffle().map(generate_and_tokenize_prompt_fun,
|
406 |
+
num_proc=os.cpu_count() // torch.cuda.device_count())
|
407 |
if drop_truncations:
|
408 |
log("avoid keeping truncated cases to avoid contaminating model with truncation cases. Original size: %s" % train_data.num_rows)
|
409 |
prune_long_sequences_func = partial(prune_long_sequences, cutoff_len=cutoff_len)
|
|
|
418 |
|
419 |
if valid_data:
|
420 |
log("Tokenizing %s validation rows" % valid_data.num_rows)
|
421 |
+
valid_data = valid_data.shuffle().map(generate_and_tokenize_prompt_fun,
|
422 |
+
num_proc=os.cpu_count() // torch.cuda.device_count())
|
423 |
val_set_size = len(valid_data)
|
424 |
else:
|
425 |
val_set_size = 0
|
|
|
474 |
elif save_steps > eval_steps:
|
475 |
# save steps must be round multiple of eval_steps
|
476 |
save_steps0 = save_steps
|
477 |
+
save_steps = max(1, (save_steps // eval_steps)) * eval_steps
|
478 |
if save_steps0 != save_steps:
|
479 |
log("Auto converted save_steps from %s to %s" % (save_steps0, save_steps))
|
480 |
|
|
|
484 |
label_ids = eval_preds.label_ids
|
485 |
predictions = eval_preds.predictions
|
486 |
|
487 |
+
# inputs = np.where(inputs != -100, inputs, tokenizer.pad_token_id)
|
488 |
+
# decoded_inputs = tokenizer.batch_decode(inputs, skip_special_tokens=True)
|
489 |
+
# decoded_inputs = [pred.strip() for pred in decoded_inputs]
|
490 |
|
491 |
label_ids = np.where(label_ids != -100, label_ids, tokenizer.pad_token_id)
|
492 |
# tokenizer behavior like generate time
|
493 |
decoded_labels = tokenizer.batch_decode(label_ids, skip_special_tokens=True,
|
494 |
+
clean_up_tokenization_spaces=True)
|
495 |
decoded_labels = [pred.strip() for pred in decoded_labels]
|
496 |
|
497 |
predictions = np.argmax(predictions, -1)
|
498 |
predictions = np.where(predictions != -100, predictions, tokenizer.pad_token_id)
|
499 |
# tokenizer behavior like generate time
|
500 |
decoded_predictions = tokenizer.batch_decode(predictions, skip_special_tokens=True,
|
501 |
+
clean_up_tokenization_spaces=True)
|
502 |
decoded_predictions = [pred.strip() for pred in decoded_predictions]
|
503 |
|
504 |
result = {}
|
|
|
547 |
load_best_model_at_end=True if val_set_size > 0 else False,
|
548 |
ddp_find_unused_parameters=False if ddp else None,
|
549 |
group_by_length=group_by_length,
|
550 |
+
# fsdp="shard_grad_op auto_wrap" if gpus > 1 and not ddp else None,
|
551 |
+
# fsdp_min_num_params=20000 if gpus > 1 and not ddp else None,
|
552 |
report_to='tensorboard' if not neptune_run else 'neptune',
|
553 |
),
|
554 |
data_collator=transformers.DataCollatorForSeq2Seq(
|
|
|
559 |
)
|
560 |
model.config.use_cache = False
|
561 |
|
562 |
if torch.__version__ >= "2" and sys.platform != "win32":
|
563 |
model = torch.compile(model)
|
564 |
# WIP (not generally replacing layers until pytorch 2.1)
|
|
|
615 |
assert prompt_type is not None
|
616 |
assert cutoff_len is not None
|
617 |
assert tokenizer is not None
|
618 |
+
prompt_dict = '' # only for custom prompt_type
|
619 |
+
assert prompt_type != PromptType.custom.name, "custom not setup for finetune"
|
620 |
+
full_prompt, _, _, _, _ = generate_prompt(data_point, prompt_type, prompt_dict, False, False, False)
|
621 |
tokenized_full_prompt = tokenize(full_prompt, tokenizer, cutoff_len, add_eos_token=add_eos_token)
|
622 |
if not train_on_inputs:
|
623 |
+
user_prompt, _, _, _, _ = generate_prompt({**data_point, "output": ""}, prompt_type, prompt_dict, False, False, False)
|
624 |
tokenized_user_prompt = tokenize(user_prompt, tokenizer, cutoff_len, add_eos_token=add_eos_token)
|
625 |
user_prompt_len = len(tokenized_user_prompt["input_ids"])
|
626 |
if add_eos_token:
|
|
|
639 |
fire.Fire(train)
|
640 |
|
641 |
|
642 |
+
def entrypoint_main():
|
643 |
CONFIG = "NCCL_P2P_LEVEL=LOC WORLD_SIZE=5 torchrun --nnodes=5 --master_addr=10.10.10.2 --master_port=1111 --nproc_per_node=1"
|
644 |
CMD = "finetune.py --data_path=config.json --num_epochs=1 --base_model=decapoda-research/llama-13b-hf"
|
645 |
log(f"""
|
|
|
666 |
|
667 |
if os.environ.get("LOCAL_RANK") is None:
|
668 |
# then not using torchrun, so can't do distributed, ensure CVD set
|
669 |
+
assert os.environ.get(
|
670 |
+
"CUDA_VISIBLE_DEVICES") is not None, "Run python script using: torchrun finetune.py OR set CUDA_VISIBLE_DEVICES to single GPU"
|
671 |
|
672 |
fire.Fire(train)
|
673 |
+
|
674 |
+
|
675 |
+
if __name__ == "__main__":
|
676 |
+
entrypoint_main()
|
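Two signature changes in this commit ripple through finetune.py and create_data.py: generate_prompt now takes a prompt_dict argument and returns five values instead of four, and get_loaders is called with keyword arguments. A minimal sketch, with an illustrative data_point (field names follow the repo's convention):

from loaders import get_loaders
from prompter import generate_prompt

data_point = dict(instruction="Summarize the release notes.", input="", output="", context="")
prompt_dict = ''  # only used when prompt_type is 'custom'

# generate_prompt now takes prompt_dict and returns a 5-tuple (was a 4-tuple).
full_prompt, _, _, _, _ = generate_prompt(data_point, 'plain', prompt_dict, False, False, False)

# get_loaders is now called with keyword arguments.
model_loader, tokenizer_loader = get_loaders(model_name='h2oai/h2ogpt-oasst1-512-20b',
                                             reward_type=False, llama_type=False)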
generate.py
CHANGED
The diff for this file is too large to render.
See raw diff
|
|
gpt4all_llm.py
CHANGED
@@ -1,23 +1,13 @@
|
|
1 |
import inspect
|
2 |
import os
|
3 |
-
import
|
4 |
from typing import Dict, Any, Optional, List
|
5 |
from langchain.callbacks.manager import CallbackManagerForLLMRun
|
6 |
from pydantic import root_validator
|
7 |
from langchain.llms import gpt4all
|
8 |
from dotenv import dotenv_values
|
9 |
|
10 |
-
|
11 |
-
class FakeTokenizer:
|
12 |
-
|
13 |
-
def encode(self, x, *args, **kwargs):
|
14 |
-
return dict(input_ids=[x])
|
15 |
-
|
16 |
-
def decode(self, x, *args, **kwargs):
|
17 |
-
return x
|
18 |
-
|
19 |
-
def __call__(self, x, *args, **kwargs):
|
20 |
-
return self.encode(x, *args, **kwargs)
|
21 |
|
22 |
|
23 |
def get_model_tokenizer_gpt4all(base_model, **kwargs):
|
@@ -73,9 +63,9 @@ class H2OStreamingStdOutCallbackHandler(StreamingStdOutCallbackHandler):
|
|
73 |
pass
|
74 |
|
75 |
|
76 |
-
def get_model_kwargs(env_kwargs, default_kwargs, cls):
|
77 |
# default from class
|
78 |
-
model_kwargs = {k: v.default for k, v in dict(inspect.signature(cls).parameters).items()}
|
79 |
# from our defaults
|
80 |
model_kwargs.update(default_kwargs)
|
81 |
# from user defaults
|
@@ -93,10 +83,14 @@ def get_llm_gpt4all(model_name,
|
|
93 |
repetition_penalty=1.0,
|
94 |
top_k=40,
|
95 |
top_p=0.7,
|
96 |
-
|
97 |
env_gpt4all_file = ".env_gpt4all"
|
98 |
env_kwargs = dotenv_values(env_gpt4all_file)
|
99 |
-
callbacks = [H2OStreamingStdOutCallbackHandler()]
|
100 |
n_ctx = env_kwargs.pop('n_ctx', 2048 - max_new_tokens)
|
101 |
default_kwargs = dict(context_erase=0.5,
|
102 |
n_batch=1,
|
@@ -113,21 +107,23 @@ def get_llm_gpt4all(model_name,
|
|
113 |
if model_name == 'llama':
|
114 |
cls = H2OLlamaCpp
|
115 |
model_path = env_kwargs.pop('model_path_llama') if model is None else model
|
116 |
-
model_kwargs = get_model_kwargs(env_kwargs, default_kwargs, cls)
|
117 |
-
model_kwargs.update(dict(model_path=model_path, callbacks=callbacks))
|
118 |
llm = cls(**model_kwargs)
|
119 |
llm.client.verbose = verbose
|
120 |
elif model_name == 'gpt4all_llama':
|
121 |
cls = H2OGPT4All
|
122 |
model_path = env_kwargs.pop('model_path_gpt4all_llama') if model is None else model
|
123 |
-
model_kwargs = get_model_kwargs(env_kwargs, default_kwargs, cls)
|
124 |
-
model_kwargs.update(
|
|
|
125 |
llm = cls(**model_kwargs)
|
126 |
elif model_name == 'gptj':
|
127 |
cls = H2OGPT4All
|
128 |
model_path = env_kwargs.pop('model_path_gptj') if model is None else model
|
129 |
-
model_kwargs = get_model_kwargs(env_kwargs, default_kwargs, cls)
|
130 |
-
model_kwargs.update(
|
|
|
131 |
llm = cls(**model_kwargs)
|
132 |
else:
|
133 |
raise RuntimeError("No such model_name %s" % model_name)
|
@@ -136,6 +132,7 @@ def get_llm_gpt4all(model_name,
|
|
136 |
|
137 |
class H2OGPT4All(gpt4all.GPT4All):
|
138 |
model: Any
|
|
|
139 |
"""Path to the pre-trained GPT4All model file."""
|
140 |
|
141 |
@root_validator()
|
@@ -155,9 +152,16 @@ class H2OGPT4All(gpt4all.GPT4All):
|
|
155 |
model_type=values["backend"],
|
156 |
allow_download=False,
|
157 |
)
|
|
158 |
else:
|
159 |
values["client"] = values["model"]
|
160 |
-
|
|
162 |
except ImportError:
|
163 |
raise ValueError(
|
@@ -171,12 +175,19 @@ class H2OGPT4All(gpt4all.GPT4All):
|
|
171 |
prompt: str,
|
172 |
stop: Optional[List[str]] = None,
|
173 |
run_manager: Optional[CallbackManagerForLLMRun] = None,
|
|
|
174 |
) -> str:
|
175 |
# Roughly 4 chars per token if natural language
|
176 |
prompt = prompt[-self.n_ctx * 4:]
|
177 |
verbose = False
|
178 |
if verbose:
|
179 |
print("_call prompt: %s" % prompt, flush=True)
|
|
|
180 |
return super()._call(prompt, stop=stop, run_manager=run_manager)
|
181 |
|
182 |
|
@@ -185,6 +196,7 @@ from langchain.llms import LlamaCpp
|
|
185 |
|
186 |
class H2OLlamaCpp(LlamaCpp):
|
187 |
model_path: Any
|
|
|
188 |
"""Path to the pre-trained GPT4All model file."""
|
189 |
|
190 |
@root_validator()
|
@@ -236,9 +248,12 @@ class H2OLlamaCpp(LlamaCpp):
|
|
236 |
prompt: str,
|
237 |
stop: Optional[List[str]] = None,
|
238 |
run_manager: Optional[CallbackManagerForLLMRun] = None,
|
|
|
239 |
) -> str:
|
240 |
verbose = False
|
241 |
# tokenize twice, just to count tokens, since llama cpp python wrapper has no way to truncate
|
|
|
|
|
242 |
prompt_tokens = self.client.tokenize(b" " + prompt.encode("utf-8"))
|
243 |
num_prompt_tokens = len(prompt_tokens)
|
244 |
if num_prompt_tokens > self.n_ctx:
|
@@ -250,6 +265,33 @@ class H2OLlamaCpp(LlamaCpp):
|
|
250 |
prompt_tokens2 = self.client.tokenize(b" " + prompt.encode("utf-8"))
|
251 |
num_prompt_tokens2 = len(prompt_tokens2)
|
252 |
print("reduced tokens from %d -> %d" % (num_prompt_tokens, num_prompt_tokens2), flush=True)
|
253 |
if verbose:
|
254 |
print("_call prompt: %s" % prompt, flush=True)
|
255 |
-
|
1 |
import inspect
|
2 |
import os
|
3 |
+
from functools import partial
|
4 |
from typing import Dict, Any, Optional, List
|
5 |
from langchain.callbacks.manager import CallbackManagerForLLMRun
|
6 |
from pydantic import root_validator
|
7 |
from langchain.llms import gpt4all
|
8 |
from dotenv import dotenv_values
|
9 |
|
10 |
+
from utils import FakeTokenizer
|
11 |
|
12 |
|
13 |
def get_model_tokenizer_gpt4all(base_model, **kwargs):
|
|
|
63 |
pass
|
64 |
|
65 |
|
66 |
+
def get_model_kwargs(env_kwargs, default_kwargs, cls, exclude_list=[]):
|
67 |
# default from class
|
68 |
+
model_kwargs = {k: v.default for k, v in dict(inspect.signature(cls).parameters).items() if k not in exclude_list}
|
69 |
# from our defaults
|
70 |
model_kwargs.update(default_kwargs)
|
71 |
# from user defaults
|
|
|
83 |
repetition_penalty=1.0,
|
84 |
top_k=40,
|
85 |
top_p=0.7,
|
86 |
+
streaming=False,
|
87 |
+
callbacks=None,
|
88 |
+
prompter=None,
|
89 |
+
verbose=False,
|
90 |
+
):
|
91 |
+
assert prompter is not None
|
92 |
env_gpt4all_file = ".env_gpt4all"
|
93 |
env_kwargs = dotenv_values(env_gpt4all_file)
|
|
|
94 |
n_ctx = env_kwargs.pop('n_ctx', 2048 - max_new_tokens)
|
95 |
default_kwargs = dict(context_erase=0.5,
|
96 |
n_batch=1,
|
|
|
107 |
if model_name == 'llama':
|
108 |
cls = H2OLlamaCpp
|
109 |
model_path = env_kwargs.pop('model_path_llama') if model is None else model
|
110 |
+
model_kwargs = get_model_kwargs(env_kwargs, default_kwargs, cls, exclude_list=['lc_kwargs'])
|
111 |
+
model_kwargs.update(dict(model_path=model_path, callbacks=callbacks, streaming=streaming, prompter=prompter))
|
112 |
llm = cls(**model_kwargs)
|
113 |
llm.client.verbose = verbose
|
114 |
elif model_name == 'gpt4all_llama':
|
115 |
cls = H2OGPT4All
|
116 |
model_path = env_kwargs.pop('model_path_gpt4all_llama') if model is None else model
|
117 |
+
model_kwargs = get_model_kwargs(env_kwargs, default_kwargs, cls, exclude_list=['lc_kwargs'])
|
118 |
+
model_kwargs.update(
|
119 |
+
dict(model=model_path, backend='llama', callbacks=callbacks, streaming=streaming, prompter=prompter))
|
120 |
llm = cls(**model_kwargs)
|
121 |
elif model_name == 'gptj':
|
122 |
cls = H2OGPT4All
|
123 |
model_path = env_kwargs.pop('model_path_gptj') if model is None else model
|
124 |
+
model_kwargs = get_model_kwargs(env_kwargs, default_kwargs, cls, exclude_list=['lc_kwargs'])
|
125 |
+
model_kwargs.update(
|
126 |
+
dict(model=model_path, backend='gptj', callbacks=callbacks, streaming=streaming, prompter=prompter))
|
127 |
llm = cls(**model_kwargs)
|
128 |
else:
|
129 |
raise RuntimeError("No such model_name %s" % model_name)
|
|
|
132 |
|
133 |
class H2OGPT4All(gpt4all.GPT4All):
|
134 |
model: Any
|
135 |
+
prompter: Any
|
136 |
"""Path to the pre-trained GPT4All model file."""
|
137 |
|
138 |
@root_validator()
|
|
|
152 |
model_type=values["backend"],
|
153 |
allow_download=False,
|
154 |
)
|
155 |
+
if values["n_threads"] is not None:
|
156 |
+
# set n_threads
|
157 |
+
values["client"].model.set_thread_count(values["n_threads"])
|
158 |
else:
|
159 |
values["client"] = values["model"]
|
160 |
+
try:
|
161 |
+
values["backend"] = values["client"].model_type
|
162 |
+
except AttributeError:
|
163 |
+
# The below is for compatibility with GPT4All Python bindings <= 0.2.3.
|
164 |
+
values["backend"] = values["client"].model.model_type
|
165 |
|
166 |
except ImportError:
|
167 |
raise ValueError(
|
|
|
175 |
prompt: str,
|
176 |
stop: Optional[List[str]] = None,
|
177 |
run_manager: Optional[CallbackManagerForLLMRun] = None,
|
178 |
+
**kwargs,
|
179 |
) -> str:
|
180 |
# Roughly 4 chars per token if natural language
|
181 |
prompt = prompt[-self.n_ctx * 4:]
|
182 |
+
|
183 |
+
# use instruct prompting
|
184 |
+
data_point = dict(context='', instruction=prompt, input='')
|
185 |
+
prompt = self.prompter.generate_prompt(data_point)
|
186 |
+
|
187 |
verbose = False
|
188 |
if verbose:
|
189 |
print("_call prompt: %s" % prompt, flush=True)
|
190 |
+
# FIXME: GPT4ALl doesn't support yield during generate, so cannot support streaming except via itself to stdout
|
191 |
return super()._call(prompt, stop=stop, run_manager=run_manager)
|
192 |
|
193 |
|
|
|
196 |
|
197 |
class H2OLlamaCpp(LlamaCpp):
|
198 |
model_path: Any
|
199 |
+
prompter: Any
|
200 |
"""Path to the pre-trained GPT4All model file."""
|
201 |
|
202 |
@root_validator()
|
|
|
248 |
prompt: str,
|
249 |
stop: Optional[List[str]] = None,
|
250 |
run_manager: Optional[CallbackManagerForLLMRun] = None,
|
251 |
+
**kwargs,
|
252 |
) -> str:
|
253 |
verbose = False
|
254 |
# tokenize twice, just to count tokens, since llama cpp python wrapper has no way to truncate
|
255 |
+
# still have to avoid crazy sizes, else hit llama_tokenize: too many tokens -- might still hit, not fatal
|
256 |
+
prompt = prompt[-self.n_ctx * 4:]
|
257 |
prompt_tokens = self.client.tokenize(b" " + prompt.encode("utf-8"))
|
258 |
num_prompt_tokens = len(prompt_tokens)
|
259 |
if num_prompt_tokens > self.n_ctx:
|
|
|
265 |
prompt_tokens2 = self.client.tokenize(b" " + prompt.encode("utf-8"))
|
266 |
num_prompt_tokens2 = len(prompt_tokens2)
|
267 |
print("reduced tokens from %d -> %d" % (num_prompt_tokens, num_prompt_tokens2), flush=True)
|
268 |
+
|
269 |
+
# use instruct prompting
|
270 |
+
data_point = dict(context='', instruction=prompt, input='')
|
271 |
+
prompt = self.prompter.generate_prompt(data_point)
|
272 |
+
|
273 |
if verbose:
|
274 |
print("_call prompt: %s" % prompt, flush=True)
|
275 |
+
|
276 |
+
if self.streaming:
|
277 |
+
text_callback = None
|
278 |
+
if run_manager:
|
279 |
+
text_callback = partial(
|
280 |
+
run_manager.on_llm_new_token, verbose=self.verbose
|
281 |
+
)
|
282 |
+
# parent handler of streamer expects to see prompt first else output="" and lose if prompt=None in prompter
|
283 |
+
if text_callback:
|
284 |
+
text_callback(prompt)
|
285 |
+
text = ""
|
286 |
+
for token in self.stream(prompt=prompt, stop=stop, run_manager=run_manager):
|
287 |
+
text_chunk = token["choices"][0]["text"]
|
288 |
+
# self.stream already calls text_callback
|
289 |
+
# if text_callback:
|
290 |
+
# text_callback(text_chunk)
|
291 |
+
text += text_chunk
|
292 |
+
return text
|
293 |
+
else:
|
294 |
+
params = self._get_parameters(stop)
|
295 |
+
params = {**params, **kwargs}
|
296 |
+
result = self.client(prompt=prompt, **params)
|
297 |
+
return result["choices"][0]["text"]
|
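get_llm_gpt4all now requires a prompter and accepts streaming and callbacks. A sketch of constructing the llama.cpp-backed LLM with the new arguments, assuming the values shown in the diff; the Prompter import path and its constructor arguments are assumptions, since the real call sites live in gpt_langchain.get_llm:

from gpt4all_llm import get_llm_gpt4all, H2OStreamingStdOutCallbackHandler
from prompter import Prompter  # assumed import; constructor args below are illustrative

prompter = Prompter('plain', '', chat=False, stream_output=False)  # hypothetical construction
llm = get_llm_gpt4all('llama', model=None, max_new_tokens=256,
                      temperature=0.1, repetition_penalty=1.0,
                      top_k=40, top_p=0.7,
                      streaming=False,
                      callbacks=[H2OStreamingStdOutCallbackHandler()],
                      prompter=prompter, verbose=False)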
gpt_langchain.py
CHANGED
@@ -1,27 +1,34 @@
|
|
|
|
1 |
import glob
|
2 |
import inspect
|
3 |
import os
|
4 |
import pathlib
|
5 |
import pickle
|
6 |
-
import queue
|
7 |
import shutil
|
8 |
import subprocess
|
9 |
-
import sys
|
10 |
import tempfile
|
|
|
11 |
import traceback
|
|
|
12 |
import uuid
|
13 |
import zipfile
|
14 |
from collections import defaultdict
|
15 |
from datetime import datetime
|
16 |
from functools import reduce
|
17 |
from operator import concat
|
|
|
18 |
|
19 |
-
from joblib import
|
|
|
|
|
20 |
from tqdm import tqdm
|
21 |
|
22 |
-
from
|
|
|
|
|
23 |
from utils import wrapped_partial, EThread, import_matplotlib, sanitize_filename, makedirs, get_url, flatten_list, \
|
24 |
-
get_device, ProgressParallel, remove, hash_file
|
|
|
25 |
|
26 |
import_matplotlib()
|
27 |
|
@@ -36,19 +43,22 @@ from langchain.chains.qa_with_sources import load_qa_with_sources_chain
|
|
36 |
from langchain.document_loaders import PyPDFLoader, TextLoader, CSVLoader, PythonLoader, TomlLoader, \
|
37 |
UnstructuredURLLoader, UnstructuredHTMLLoader, UnstructuredWordDocumentLoader, UnstructuredMarkdownLoader, \
|
38 |
EverNoteLoader, UnstructuredEmailLoader, UnstructuredODTLoader, UnstructuredPowerPointLoader, \
|
39 |
-
UnstructuredEPubLoader, UnstructuredImageLoader, UnstructuredRTFLoader, ArxivLoader
|
40 |
-
from langchain.text_splitter import RecursiveCharacterTextSplitter
|
41 |
from langchain.chains.question_answering import load_qa_chain
|
42 |
from langchain.docstore.document import Document
|
43 |
-
from langchain import PromptTemplate
|
44 |
from langchain.vectorstores import Chroma
|
45 |
|
46 |
|
47 |
-
def get_db(sources, use_openai_embedding=False, db_type='faiss',
|
|
|
|
|
48 |
collection_name=None,
|
49 |
hf_embedding_model="sentence-transformers/all-MiniLM-L6-v2"):
|
50 |
if not sources:
|
51 |
return None
|
|
|
52 |
# get embedding model
|
53 |
embedding = get_embedding(use_openai_embedding, hf_embedding_model=hf_embedding_model)
|
54 |
assert collection_name is not None or langchain_mode != 'notset'
|
@@ -59,29 +69,41 @@ def get_db(sources, use_openai_embedding=False, db_type='faiss', persist_directo
|
|
59 |
if db_type == 'faiss':
|
60 |
from langchain.vectorstores import FAISS
|
61 |
db = FAISS.from_documents(sources, embedding)
|
62 |
-
|
63 |
elif db_type == 'weaviate':
|
64 |
import weaviate
|
65 |
from weaviate.embedded import EmbeddedOptions
|
66 |
from langchain.vectorstores import Weaviate
|
67 |
|
68 |
-
|
69 |
-
|
70 |
-
|
71 |
-
|
|
|
|
|
72 |
index_name = collection_name.capitalize()
|
73 |
db = Weaviate.from_documents(documents=sources, embedding=embedding, client=client, by_text=False,
|
74 |
index_name=index_name)
|
75 |
-
|
76 |
elif db_type == 'chroma':
|
77 |
assert persist_directory is not None
|
78 |
os.makedirs(persist_directory, exist_ok=True)
|
79 |
-
|
80 |
-
|
81 |
-
|
82 |
-
|
83 |
-
|
84 |
-
|
|
|
|
85 |
else:
|
86 |
raise RuntimeError("No such db_type=%s" % db_type)
|
87 |
|
@@ -104,7 +126,10 @@ def _get_unique_sources_in_weaviate(db):
|
|
104 |
|
105 |
def add_to_db(db, sources, db_type='faiss',
|
106 |
avoid_dup_by_file=False,
|
107 |
-
avoid_dup_by_content=True
|
108 |
num_new_sources = len(sources)
|
109 |
if not sources:
|
110 |
return db, num_new_sources, []
|
@@ -120,7 +145,7 @@ def add_to_db(db, sources, db_type='faiss',
|
|
120 |
return db, num_new_sources, []
|
121 |
db.add_documents(documents=sources)
|
122 |
elif db_type == 'chroma':
|
123 |
-
collection = db
|
124 |
# files we already have:
|
125 |
metadata_files = set([x['source'] for x in collection['metadatas']])
|
126 |
if avoid_dup_by_file:
|
@@ -135,11 +160,15 @@ def add_to_db(db, sources, db_type='faiss',
|
|
135 |
[x['hashid'] for x in collection['metadatas'] if 'hashid' in x and x['hashid'] not in ["None", None]])
|
136 |
# avoid sources with same hash
|
137 |
sources = [x for x in sources if x.metadata.get('hashid') not in metadata_hash_ids]
|
|
138 |
# get new file names that match existing file names. delete existing files we are overridding
|
139 |
dup_metadata_files = set([x.metadata['source'] for x in sources if x.metadata['source'] in metadata_files])
|
140 |
print("Removing %s duplicate files from db because ingesting those as new documents" % len(
|
141 |
dup_metadata_files), flush=True)
|
142 |
-
client_collection = db._client.get_collection(name=db._collection.name
|
|
|
143 |
for dup_file in dup_metadata_files:
|
144 |
dup_file_meta = dict(source=dup_file)
|
145 |
try:
|
@@ -151,6 +180,8 @@ def add_to_db(db, sources, db_type='faiss',
|
|
151 |
return db, num_new_sources, []
|
152 |
db.add_documents(documents=sources)
|
153 |
db.persist()
|
|
|
|
|
154 |
else:
|
155 |
raise RuntimeError("No such db_type=%s" % db_type)
|
156 |
|
@@ -165,10 +196,13 @@ def create_or_update_db(db_type, persist_directory, collection_name,
|
|
165 |
import weaviate
|
166 |
from weaviate.embedded import EmbeddedOptions
|
167 |
|
168 |
-
|
169 |
-
|
170 |
-
|
171 |
-
|
172 |
index_name = collection_name.replace(' ', '_').capitalize()
|
173 |
if client.schema.exists(index_name) and not add_if_exists:
|
174 |
client.schema.delete_class(index_name)
|
@@ -205,14 +239,20 @@ def get_embedding(use_openai_embedding, hf_embedding_model="sentence-transformer
|
|
205 |
if use_openai_embedding:
|
206 |
assert os.getenv("OPENAI_API_KEY") is not None, "Set ENV OPENAI_API_KEY"
|
207 |
from langchain.embeddings import OpenAIEmbeddings
|
208 |
-
embedding = OpenAIEmbeddings()
|
209 |
else:
|
210 |
# to ensure can fork without deadlock
|
211 |
from langchain.embeddings import HuggingFaceEmbeddings
|
212 |
|
213 |
device, torch_dtype, context_class = get_device_dtype()
|
214 |
model_kwargs = dict(device=device)
|
215 |
-
|
|
|
|
216 |
return embedding
|
217 |
|
218 |
|
@@ -226,63 +266,481 @@ def get_answer_from_sources(chain, sources, question):
|
|
226 |
)["output_text"]
|
227 |
|
228 |
|
229 |
-
|
230 |
-
|
231 |
-
|
232 |
temperature=0.1,
|
233 |
-
repetition_penalty=1.0,
|
234 |
top_k=40,
|
235 |
top_p=0.7,
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
236 |
prompt_type=None,
|
|
|
237 |
prompter=None,
|
|
|
238 |
verbose=False,
|
239 |
):
|
240 |
-
if use_openai_model:
|
241 |
-
|
242 |
-
|
243 |
-
|
244 |
-
|
245 |
-
|
246 |
elif model_name in non_hf_types:
|
247 |
from gpt4all_llm import get_llm_gpt4all
|
248 |
llm = get_llm_gpt4all(model_name, model=model, max_new_tokens=max_new_tokens,
|
249 |
temperature=temperature,
|
250 |
repetition_penalty=repetition_penalty,
|
251 |
top_k=top_k,
|
252 |
top_p=top_p,
|
|
|
253 |
verbose=verbose,
|
|
|
|
|
254 |
)
|
255 |
-
streamer = None
|
256 |
-
prompt_type = 'plain'
|
257 |
else:
|
258 |
-
from transformers import AutoTokenizer, AutoModelForCausalLM
|
259 |
-
|
260 |
if model is None:
|
261 |
# only used if didn't pass model in
|
262 |
-
assert model_name is None
|
263 |
assert tokenizer is None
|
264 |
prompt_type = 'human_bot'
|
265 |
-
model_name
|
266 |
-
|
267 |
-
|
268 |
-
|
269 |
-
|
270 |
-
|
271 |
-
|
272 |
-
load_8bit = True
|
273 |
-
# FIXME: for now not to spread across hetero GPUs
|
274 |
-
# device_map={"": 0} if load_8bit and device == 'cuda' else "auto"
|
275 |
-
device_map = {"": 0} if device == 'cuda' else "auto"
|
276 |
-
model = AutoModelForCausalLM.from_pretrained(model_name,
|
277 |
-
device_map=device_map,
|
278 |
-
torch_dtype=torch_dtype,
|
279 |
-
load_in_8bit=load_8bit)
|
280 |
|
281 |
max_max_tokens = tokenizer.model_max_length
|
282 |
-
gen_kwargs = dict(
|
|
283 |
return_full_text=True,
|
284 |
-
|
285 |
-
|
286 |
|
287 |
if stream_output:
|
288 |
skip_prompt = False
|
@@ -297,10 +755,12 @@ def get_llm(use_openai_model=False, model_name=None, model=None,
|
|
297 |
pipe = H2OTextGenerationPipeline(model=model, use_prompter=True,
|
298 |
prompter=prompter,
|
299 |
prompt_type=prompt_type,
|
300 |
-
|
|
|
301 |
chat=False, stream_output=stream_output,
|
302 |
tokenizer=tokenizer,
|
303 |
-
|
|
|
304 |
**gen_kwargs)
|
305 |
# pipe.task = "text-generation"
|
306 |
# below makes it listen only to our prompt removal,
|
@@ -345,7 +805,7 @@ def get_wiki_data(title, first_paragraph_only, text_limit=None, take_head=True):
|
|
345 |
data = json.load(open(filename, "rt"))
|
346 |
page_content = list(data["query"]["pages"].values())[0]["extract"]
|
347 |
if take_head is not None and text_limit is not None:
|
348 |
-
page_content = page_content[:text_limit] if take_head else page_content[
|
349 |
title_url = str(title).replace(' ', '_')
|
350 |
return Document(
|
351 |
page_content=page_content,
|
@@ -467,6 +927,21 @@ try:
|
|
467 |
except (pkg_resources.DistributionNotFound, AssertionError):
|
468 |
have_pymupdf = False
|
469 |
|
470 |
image_types = ["png", "jpg", "jpeg"]
|
471 |
non_image_types = ["pdf", "txt", "csv", "toml", "py", "rst", "rtf",
|
472 |
"md", "html",
|
@@ -484,12 +959,13 @@ file_types = non_image_types + image_types
|
|
484 |
def add_meta(docs1, file):
|
485 |
file_extension = pathlib.Path(file).suffix
|
486 |
hashid = hash_file(file)
|
487 |
-
if not isinstance(docs1, list):
|
488 |
docs1 = [docs1]
|
489 |
[x.metadata.update(dict(input_type=file_extension, date=str(datetime.now), hashid=hashid)) for x in docs1]
|
490 |
|
491 |
|
492 |
-
def file_to_doc(file, base_path=None, verbose=False, fail_any_exception=False,
|
|
|
493 |
is_url=False, is_txt=False,
|
494 |
enable_captions=True,
|
495 |
captions_model=None,
|
@@ -525,9 +1001,25 @@ def file_to_doc(file, base_path=None, verbose=False, fail_any_exception=False, c
|
|
525 |
else:
|
526 |
docs1 = []
|
527 |
else:
|
|
|
|
|
528 |
docs1 = UnstructuredURLLoader(urls=[file]).load()
|
529 |
[x.metadata.update(dict(input_type='url', date=str(datetime.now))) for x in docs1]
|
530 |
-
|
|
|
531 |
elif is_txt:
|
532 |
base_path = "user_paste"
|
533 |
source_file = os.path.join(base_path, "_%s" % str(uuid.uuid4())[:10])
|
@@ -536,44 +1028,49 @@ def file_to_doc(file, base_path=None, verbose=False, fail_any_exception=False, c
|
|
536 |
f.write(file)
|
537 |
metadata = dict(source=source_file, date=str(datetime.now()), input_type='pasted txt')
|
538 |
doc1 = Document(page_content=file, metadata=metadata)
|
|
|
539 |
elif file.lower().endswith('.html') or file.lower().endswith('.mhtml'):
|
540 |
docs1 = UnstructuredHTMLLoader(file_path=file).load()
|
541 |
add_meta(docs1, file)
|
542 |
-
|
|
|
543 |
elif (file.lower().endswith('.docx') or file.lower().endswith('.doc')) and have_libreoffice:
|
544 |
docs1 = UnstructuredWordDocumentLoader(file_path=file).load()
|
545 |
add_meta(docs1, file)
|
546 |
-
doc1 = chunk_sources(docs1, chunk_size=chunk_size)
|
547 |
elif file.lower().endswith('.odt'):
|
548 |
docs1 = UnstructuredODTLoader(file_path=file).load()
|
549 |
add_meta(docs1, file)
|
550 |
-
doc1 = chunk_sources(docs1, chunk_size=chunk_size)
|
551 |
elif file.lower().endswith('pptx') or file.lower().endswith('ppt'):
|
552 |
docs1 = UnstructuredPowerPointLoader(file_path=file).load()
|
553 |
add_meta(docs1, file)
|
554 |
-
|
|
|
555 |
elif file.lower().endswith('.txt'):
|
556 |
# use UnstructuredFileLoader ?
|
557 |
docs1 = TextLoader(file, encoding="utf8", autodetect_encoding=True).load()
|
558 |
# makes just one, but big one
|
559 |
-
doc1 = chunk_sources(docs1, chunk_size=chunk_size)
|
|
|
560 |
add_meta(doc1, file)
|
561 |
elif file.lower().endswith('.rtf'):
|
562 |
docs1 = UnstructuredRTFLoader(file).load()
|
563 |
add_meta(docs1, file)
|
564 |
-
doc1 = chunk_sources(docs1, chunk_size=chunk_size)
|
565 |
elif file.lower().endswith('.md'):
|
566 |
docs1 = UnstructuredMarkdownLoader(file).load()
|
567 |
add_meta(docs1, file)
|
568 |
-
|
|
|
569 |
elif file.lower().endswith('.enex'):
|
570 |
docs1 = EverNoteLoader(file).load()
|
571 |
add_meta(doc1, file)
|
572 |
-
doc1 = chunk_sources(docs1, chunk_size=chunk_size)
|
573 |
elif file.lower().endswith('.epub'):
|
574 |
docs1 = UnstructuredEPubLoader(file).load()
|
575 |
add_meta(docs1, file)
|
576 |
-
doc1 = chunk_sources(docs1, chunk_size=chunk_size)
|
577 |
elif file.lower().endswith('.jpeg') or file.lower().endswith('.jpg') or file.lower().endswith('.png'):
|
578 |
docs1 = []
|
579 |
if have_tesseract and enable_ocr:
|
@@ -603,7 +1100,7 @@ def file_to_doc(file, base_path=None, verbose=False, fail_any_exception=False, c
|
|
603 |
doci.metadata['source'] = doci.metadata['image_path']
|
604 |
doci.metadata['hash'] = hash_file(doci.metadata['source'])
|
605 |
if docs1:
|
606 |
-
doc1 = chunk_sources(docs1, chunk_size=chunk_size)
|
607 |
elif file.lower().endswith('.msg'):
|
608 |
raise RuntimeError("Not supported, GPL3 license")
|
609 |
# docs1 = OutlookMessageLoader(file).load()
|
@@ -612,14 +1109,14 @@ def file_to_doc(file, base_path=None, verbose=False, fail_any_exception=False, c
|
|
612 |
try:
|
613 |
docs1 = UnstructuredEmailLoader(file).load()
|
614 |
add_meta(docs1, file)
|
615 |
-
doc1 = chunk_sources(docs1, chunk_size=chunk_size)
|
616 |
except ValueError as e:
|
617 |
if 'text/html content not found in email' in str(e):
|
618 |
# e.g. plain/text dict key exists, but not
|
619 |
# doc1 = TextLoader(file, encoding="utf8").load()
|
620 |
docs1 = UnstructuredEmailLoader(file, content_source="text/plain").load()
|
621 |
add_meta(docs1, file)
|
622 |
-
doc1 = chunk_sources(docs1, chunk_size=chunk_size)
|
623 |
else:
|
624 |
raise
|
625 |
# elif file.lower().endswith('.gcsdir'):
|
@@ -630,6 +1127,7 @@ def file_to_doc(file, base_path=None, verbose=False, fail_any_exception=False, c
        with open(file, "r") as f:
            doc1 = Document(page_content=f.read(), metadata={"source": file})
        add_meta(doc1, file)
    elif file.lower().endswith('.pdf'):
        env_gpt4all_file = ".env_gpt4all"
        from dotenv import dotenv_values
@@ -638,11 +1136,19 @@ def file_to_doc(file, base_path=None, verbose=False, fail_any_exception=False, c
        if have_pymupdf and pdf_class_name == 'PyMuPDFParser':
            # GPL, only use if installed
            from langchain.document_loaders import PyMuPDFLoader
-
        else:
            # open-source fallback
-
        # Some PDFs return nothing or junk from PDFMinerLoader
        add_meta(doc1, file)
    elif file.lower().endswith('.csv'):
        doc1 = CSVLoader(file).load()
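The two removed loader calls above are replaced in the expanded hunk later in this diff (PyMuPDFLoader when available, then UnstructuredPDFLoader, then PyPDFLoader). A small sketch of that fallback order; the file path is a placeholder:

from langchain.document_loaders import PyPDFLoader

def load_pdf(path, have_pymupdf=False, pdf_loader_name='PyPDFLoader'):
    # Prefer PyMuPDF when installed (GPL), otherwise fall back to the pure-Python PyPDFLoader
    if have_pymupdf and pdf_loader_name == 'PyMuPDFParser':
        from langchain.document_loaders import PyMuPDFLoader
        return PyMuPDFLoader(path).load()
    return PyPDFLoader(path).load()

pages = load_pdf("example.pdf")  # placeholder path; returns one Document per page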
@@ -650,6 +1156,7 @@ def file_to_doc(file, base_path=None, verbose=False, fail_any_exception=False, c
    elif file.lower().endswith('.py'):
        doc1 = PythonLoader(file).load()
        add_meta(doc1, file)
    elif file.lower().endswith('.toml'):
        doc1 = TomlLoader(file).load()
        add_meta(doc1, file)
@@ -657,7 +1164,7 @@ def file_to_doc(file, base_path=None, verbose=False, fail_any_exception=False, c
        with open(file, "r") as f:
            docs1 = UnstructuredURLLoader(urls=f.readlines()).load()
        add_meta(docs1, file)
-       doc1 = chunk_sources(docs1, chunk_size=chunk_size)
    elif file.lower().endswith('.zip'):
        with zipfile.ZipFile(file, 'r') as zip_ref:
            # don't put into temporary path, since want to keep references to docs inside zip
@@ -672,12 +1179,12 @@ def file_to_doc(file, base_path=None, verbose=False, fail_any_exception=False, c
    # if list of length one, don't trust and chunk it
    if not isinstance(doc1, list):
        if chunk:
-           docs = chunk_sources([doc1], chunk_size=chunk_size)
        else:
            docs = [doc1]
    elif isinstance(doc1, list) and len(doc1) == 1:
        if chunk:
-           docs = chunk_sources(doc1, chunk_size=chunk_size)
        else:
            docs = doc1
    else:
@@ -687,7 +1194,8 @@ def file_to_doc(file, base_path=None, verbose=False, fail_any_exception=False, c
    return docs


-def path_to_doc1(file, verbose=False, fail_any_exception=False, return_file=True,
                 is_url=False, is_txt=False,
                 enable_captions=True,
                 captions_model=None,
@@ -739,15 +1247,16 @@ def path_to_docs(path_or_paths, verbose=False, fail_any_exception=False, n_jobs=
                 existing_files=[],
                 existing_hash_ids={},
                 ):
    globs_image_types = []
    globs_non_image_types = []
    if not path_or_paths and not url and not text:
        return []
    elif url:
-       globs_non_image_types = [url]
    elif text:
-       globs_non_image_types = [text]
-   elif isinstance(path_or_paths, str):
        # single path, only consume allowed files
        path = path_or_paths
        # Below globs should match patterns in file_to_doc()
@@ -756,8 +1265,11 @@ def path_to_docs(path_or_paths, verbose=False, fail_any_exception=False, n_jobs=
        [globs_non_image_types.extend(glob.glob(os.path.join(path, "./**/*.%s" % ftype), recursive=True))
         for ftype in non_image_types]
    else:
        # list/tuple of files (consume what can, and exception those that selected but cannot consume so user knows)
-       assert isinstance(path_or_paths, (list, tuple)), "Wrong type for path_or_paths: %s" % type(
        # reform out of allowed types
        globs_image_types.extend(flatten_list([[x for x in path_or_paths if x.endswith(y)] for y in image_types]))
        # could do below:
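path_to_docs collects candidate files by recursive glob per allowed extension, as in the hunk above. A tiny self-contained sketch of that collection step; the extension list here is a placeholder, not the module's non_image_types:

import glob
import os

def collect_files(path, file_types=("pdf", "txt", "md")):
    # One recursive glob per allowed extension, mirroring the pattern used above
    found = []
    for ftype in file_types:
        found.extend(glob.glob(os.path.join(path, "**/*.%s" % ftype), recursive=True))
    return found

print(collect_files("."))  # placeholder directory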
@@ -861,12 +1373,12 @@ def prep_langchain(persist_directory,

    if db_dir_exists and user_path is None:
        print("Prep: persist_directory=%s exists, using" % persist_directory, flush=True)
-       db = get_existing_db(persist_directory, load_db_if_exists, db_type, use_openai_embedding, langchain_mode,
                             hf_embedding_model)
    else:
        if db_dir_exists and user_path is not None:
            print("Prep: persist_directory=%s exists, user_path=%s passed, adding any changed or new documents" % (
-
        elif not db_dir_exists:
            print("Prep: persist_directory=%s does not exist, regenerating" % persist_directory, flush=True)
        db = None
@@ -912,24 +1424,78 @@ class FakeConsumer(object):
posthog.Consumer = FakeConsumer


-def
-
    if load_db_if_exists and db_type == 'chroma' and os.path.isdir(persist_directory) and os.path.isdir(
            os.path.join(persist_directory, 'index')):
-
-
-
-
-
-
-
-
-
-
        return db
    return None


def make_db(**langchain_kwargs):
    func_names = list(inspect.signature(_make_db).parameters)
    missing_kwargs = [x for x in func_names if x not in langchain_kwargs]
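make_db above forwards **langchain_kwargs into _make_db only after checking them against its signature; the same pattern appears again in run_qa_db further down. A small self-contained sketch of that keyword-filtering idiom (the target function here is a stand-in):

import inspect

def _target(a, b=2, c=3):
    return a + b + c

def call_with_known_kwargs(func, **kwargs):
    # Keep only keyword arguments that the wrapped function actually declares
    func_names = list(inspect.signature(func).parameters)
    kwargs = {k: v for k, v in kwargs.items() if k in func_names}
    return func(**kwargs)

print(call_with_known_kwargs(_target, a=1, b=5, unused=None))  # -> 9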
@@ -945,9 +1511,33 @@ def make_db(**langchain_kwargs):
    return _make_db(**langchain_kwargs)


def _make_db(use_openai_embedding=False,
             hf_embedding_model="sentence-transformers/all-MiniLM-L6-v2",
-            first_para=False, text_limit=None,
             langchain_mode=None,
             user_path=None,
             db_type='faiss',
@@ -955,19 +1545,13 @@ def _make_db(use_openai_embedding=False,
             db=None,
             n_jobs=-1,
             verbose=False):
-   persist_directory =
-
-
-
-
-
-
-   client_settings = Settings(anonymized_telemetry=False,
-                              chroma_db_impl="duckdb+parquet",
-                              persist_directory=persist_directory)
-   db = Chroma(persist_directory=persist_directory, embedding_function=embedding,
-               collection_name=langchain_mode.replace(' ', '_'),
-               client_settings=client_settings)
    sources = []
    if not db and langchain_mode not in ['MyData'] or \
            user_path is not None and \
@@ -992,24 +1576,24 @@ def _make_db(use_openai_embedding=False,
            sources1 = get_all_documents(small_test=small_test, n_jobs=os.cpu_count() // 2)
            print("Got new wiki", flush=True)
            if chunk:
-               sources1 = chunk_sources(sources1, chunk_size=chunk_size)
                print("Chunked new wiki", flush=True)
            sources.extend(sources1)
        if langchain_mode in ['wiki', 'All', "'All'"]:
            sources1 = get_wiki_sources(first_para=first_para, text_limit=text_limit)
            if chunk:
-               sources1 = chunk_sources(sources1, chunk_size=chunk_size)
            sources.extend(sources1)
        if langchain_mode in ['github h2oGPT', 'All', "'All'"]:
            # sources = get_github_docs("dagster-io", "dagster")
            sources1 = get_github_docs("h2oai", "h2ogpt")
            # FIXME: always chunk for now
-           sources1 = chunk_sources(sources1, chunk_size=chunk_size)
            sources.extend(sources1)
        if langchain_mode in ['DriverlessAI docs', 'All', "'All'"]:
            sources1 = get_dai_docs(from_hf=True)
            if chunk and False:  # FIXME: DAI docs are already chunked well, should only chunk more if over limit
-               sources1 = chunk_sources(sources1, chunk_size=chunk_size)
            sources.extend(sources1)
        if langchain_mode in ['All', 'UserData']:
            if user_path:
@@ -1023,6 +1607,8 @@ def _make_db(use_openai_embedding=False,
                existing_files = []
                existing_hash_ids = []
            # chunk internally for speed over multiple docs
            sources1 = path_to_docs(user_path, n_jobs=n_jobs, chunk=chunk, chunk_size=chunk_size,
                                    existing_files=existing_files, existing_hash_ids=existing_hash_ids)
            new_metadata_sources = set([x.metadata['source'] for x in sources1])
@@ -1066,7 +1652,9 @@ def _make_db(use_openai_embedding=False,
        new_sources_metadata = [x.metadata for x in sources]
    elif user_path is not None and langchain_mode in ['UserData']:
        print("Existing db, potentially adding %s sources from user_path=%s" % (len(sources), user_path), flush=True)
-       db, num_new_sources, new_sources_metadata = add_to_db(db, sources, db_type=db_type
        print("Existing db, added %s new sources from user_path=%s" % (num_new_sources, user_path), flush=True)
    else:
        new_sources_metadata = [x.metadata for x in sources]
@@ -1074,63 +1662,140 @@ def _make_db(use_openai_embedding=False,
    return db, len(new_sources_metadata), new_sources_metadata


def get_existing_files(db):
-
-   metadata_sources = set([x['source'] for x in
    return metadata_sources


def get_existing_hash_ids(db):
-
    # assume consistency, that any prior hashed source was single hashed file at the time among all source chunks
-   metadata_hash_ids = {x['source']: x.get('hashid') for x in
    return metadata_hash_ids


-source_prefix = "Sources [Score | Link]:"
-source_postfix = "End Sources<p>"
-
-
def run_qa_db(**kwargs):
    func_names = list(inspect.signature(_run_qa_db).parameters)
    # hard-coded defaults
    kwargs['answer_with_sources'] = True
-   kwargs['sanitize_bot_response'] = True
    kwargs['show_rank'] = False
    missing_kwargs = [x for x in func_names if x not in kwargs]
    assert not missing_kwargs, "Missing kwargs: %s" % missing_kwargs
    # only keep actual used
    kwargs = {k: v for k, v in kwargs.items() if k in func_names}
-


def _run_qa_db(query=None,
               use_openai_model=False, use_openai_embedding=False,
-              first_para=False, text_limit=None,
               user_path=None,
               detect_user_path_changes_every_query=False,
               db_type='faiss',
-              model_name=None, model=None, tokenizer=None,
               hf_embedding_model="sentence-transformers/all-MiniLM-L6-v2",
               stream_output=False,
               prompter=None,
               prompt_type=None,
               answer_with_sources=True,
               cut_distanct=1.1,
-              sanitize_bot_response=
               show_rank=False,
               load_db_if_exists=False,
               db=None,
-
               temperature=0.1,
-              repetition_penalty=1.0,
               top_k=40,
               top_p=0.7,
               langchain_mode=None,
-              document_choice=[
               n_jobs=-1,
               verbose=False,
-              cli=False
    """

    :param query:
@@ -1149,39 +1814,63 @@ def _run_qa_db(query=None,
    :param answer_with_sources
    :return:
    """
    assert query is not None
    assert prompter is not None or prompt_type is not None or model is None  # if model is None, then will generate
    if prompter is not None:
        prompt_type = prompter.prompt_type
    if model is not None:
        assert prompt_type is not None
    llm, model_name, streamer, prompt_type_out = get_llm(use_openai_model=use_openai_model, model_name=model_name,
-                                                        model=model,
                                                         stream_output=stream_output,
-
                                                         temperature=temperature,
-                                                        repetition_penalty=repetition_penalty,
                                                         top_k=top_k,
                                                         top_p=top_p,
                                                         prompt_type=prompt_type,
                                                         prompter=prompter,
                                                         verbose=verbose,
                                                         )

-   if model_name in non_hf_types:
-       # FIXME: for now, streams to stdout/stderr currently
-       stream_output = False
-
    use_context = False
    scores = []
    chain = None

    func_names = list(inspect.signature(get_similarity_chain).parameters)
    sim_kwargs = {k: v for k, v in locals().items() if k in func_names}
    missing_kwargs = [x for x in func_names if x not in sim_kwargs]
    assert not missing_kwargs, "Missing: %s" % missing_kwargs
    docs, chain, scores, use_context = get_similarity_chain(**sim_kwargs)
-   if
        formatted_doc_chunks = '\n\n'.join([get_url(x) + '\n\n' + x.page_content for x in docs])
        yield formatted_doc_chunks, ''
        return
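The sources-only path above just renders the retrieved chunks as links plus text. A minimal sketch of that kind of formatting; the Documents and the link helper are placeholders for the module's get_url:

from langchain.docstore.document import Document

def format_chunks(docs):
    # Render each retrieved chunk as a source link followed by its text, separated by blank lines
    def link(d):
        src = d.metadata.get("source", "")
        return '<a href="file/%s">%s</a>' % (src, src)
    return "\n\n".join(link(d) + "\n\n" + d.page_content for d in docs)

docs = [Document(page_content="chunk text", metadata=dict(source="user_path/example.txt"))]
print(format_chunks(docs))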
@@ -1189,43 +1878,49 @@ def _run_qa_db(query=None,
        # can only return if HF type
        return

    if not use_context:
        ret = answer['output_text']
@@ -1239,22 +1934,31 @@ def _run_qa_db(query=None,

def get_similarity_chain(query=None,
                         use_openai_model=False, use_openai_embedding=False,
-                        first_para=False, text_limit=None,
                         user_path=None,
                         detect_user_path_changes_every_query=False,
                         db_type='faiss',
                         model_name=None,
                         hf_embedding_model="sentence-transformers/all-MiniLM-L6-v2",
                         prompt_type=None,
                         cut_distanct=1.1,
                         load_db_if_exists=False,
                         db=None,
                         langchain_mode=None,
-                        document_choice=[
                         n_jobs=-1,
                         # beyond run_db_query:
                         llm=None,
                         verbose=False,
                         ):
    # determine whether use of context out of docs is planned
    if not use_openai_model and prompt_type not in ['plain'] or model_name in non_hf_types:
@@ -1266,10 +1970,14 @@ def get_similarity_chain(query=None,
        use_context = True

    # https://github.com/hwchase17/langchain/issues/1946
-   # FIXME: Seems to way to get size of chroma db to limit
    # Chroma collection MyData contains fewer than 4 elements.
    # type logger error
-

    # FIXME: For All just go over all dbs instead of a separate db for All
    if not detect_user_path_changes_every_query and db is not None:
@@ -1279,7 +1987,8 @@ def get_similarity_chain(query=None,
        user_path = None
    db, num_new_sources, new_sources_metadata = make_db(use_openai_embedding=use_openai_embedding,
                                                        hf_embedding_model=hf_embedding_model,
-                                                       first_para=first_para, text_limit=text_limit,
                                                        chunk_size=chunk_size,
                                                        langchain_mode=langchain_mode,
                                                        user_path=user_path,
@@ -1289,37 +1998,133 @@ def get_similarity_chain(query=None,
                                                        n_jobs=n_jobs,
                                                        verbose=verbose)

    if db and use_context:
-       if isinstance(
-           #
-           document_choice = [document_choice]
-       if not isinstance(db, Chroma) or \
-               len(document_choice) == 0 or \
-               len(document_choice) <= 1 and document_choice[0] == 'All':
-           # treat empty list as All for now, not 'None'
-           filter_kwargs = {}
-       elif len(document_choice) > 0 and document_choice[0] == 'Only':
-           # Only means All docs, but only will return sources, not LLM response
            filter_kwargs = {}
        else:
            if len(document_choice) >= 2:
                or_filter = [{"source": {"$eq": x}} for x in document_choice]
                filter_kwargs = dict(filter={"$or": or_filter})
-           elif len(document_choice)
                one_filter = [{"source": {"$eq": x}} for x in document_choice][0]
                filter_kwargs = dict(filter=one_filter)
            else:
                filter_kwargs = {}
-
-
-
-
-
-
-
-
-
-
    else:
        docs = []
        scores = []
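The document_choice handling above turns a list of selected files into a Chroma metadata filter. A minimal sketch of how such a filter restricts retrieval to chosen sources; the collection contents here are placeholders:

from langchain.docstore.document import Document
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.vectorstores import Chroma

embedding = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")
docs = [Document(page_content="alpha text", metadata={"source": "a.txt"}),
        Document(page_content="beta text", metadata={"source": "b.txt"})]
db = Chroma.from_documents(documents=docs, embedding=embedding, collection_name="demo")

# Restrict similarity search to a chosen subset of sources via a $or metadata filter
chosen = ["a.txt", "b.txt"]
filter_kwargs = dict(filter={"$or": [{"source": {"$eq": x}} for x in chosen]})
hits = db.similarity_search_with_score("alpha", k=2, **filter_kwargs)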
@@ -1328,7 +2133,7 @@ def get_similarity_chain(query=None,
        # if HF type and have no docs, can bail out
        return docs, None, [], False

-   if
        # no LLM use
        return docs, None, [], False

@@ -1348,19 +2153,11 @@ def get_similarity_chain(query=None,
    if len(docs) == 0:
        # avoid context == in prompt then
        use_context = False

-   if
        # instruct-like, rather than few-shot prompt_type='plain' as default
        # but then sources confuse the model with how inserted among rest of text, so avoid
-       prefix = ""
-       if langchain_mode in ['Disabled', 'ChatLLM', 'LLM'] or not use_context:
-           template = """%s{context}{question}""" % prefix
-       else:
-           template = """%s
-==
-{context}
-==
-{question}""" % prefix
        prompt = PromptTemplate(
            # input_variables=["summaries", "question"],
            input_variables=["context", "question"],
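The removed lines above built a context/question prompt that wraps retrieved chunks between separators. A minimal sketch reusing that removed template as-is (the prefix is empty here, as in the old code):

from langchain import PromptTemplate

prefix = ""
template = """%s
==
{context}
==
{question}""" % prefix
prompt = PromptTemplate(input_variables=["context", "question"], template=template)
print(prompt.format(context="Retrieved chunks go here.", question="What is h2oGPT?"))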
@@ -1420,15 +2217,32 @@ def get_sources_answer(query, answer, scores, show_rank, answer_with_sources, ve
    return ret, extra


-def
-
-
-
-
-
-
-
-
    return source_chunks


@@ -1439,6 +2253,8 @@ def get_db_from_hf(dest=".", db_dir='db_dir_DriverlessAI_docs.zip'):
    path_to_zip_file = hf_hub_download('h2oai/db_dirs', db_dir, token=token, repo_type='dataset')
    import zipfile
    with zipfile.ZipFile(path_to_zip_file, 'r') as zip_ref:
        zip_ref.extractall(dest)
    return path_to_zip_file

@@ -1467,5 +2283,28 @@ def get_some_dbs_from_hf(dest='.', db_zips=None):
    assert os.path.isdir(os.path.join(dest, dir_expected, 'index')), "Missing index in %s" % dir_expected


if __name__ == '__main__':
    pass
+import ast
 import glob
 import inspect
 import os
 import pathlib
 import pickle
 import shutil
 import subprocess
 import tempfile
+import time
 import traceback
+import types
 import uuid
 import zipfile
 from collections import defaultdict
 from datetime import datetime
 from functools import reduce
 from operator import concat
+import filelock

+from joblib import delayed
+from langchain.callbacks import streaming_stdout
+from langchain.embeddings import HuggingFaceInstructEmbeddings
 from tqdm import tqdm

+from enums import DocumentChoices, no_lora_str, model_token_mapping, source_prefix, source_postfix
+from generate import gen_hyper, get_model, SEED
+from prompter import non_hf_types, PromptType, Prompter
 from utils import wrapped_partial, EThread, import_matplotlib, sanitize_filename, makedirs, get_url, flatten_list, \
+    get_device, ProgressParallel, remove, hash_file, clear_torch_cache, NullContext, get_hf_server, FakeTokenizer
+from utils_langchain import StreamingGradioCallbackHandler

 import_matplotlib()

43 |
from langchain.document_loaders import PyPDFLoader, TextLoader, CSVLoader, PythonLoader, TomlLoader, \
|
44 |
UnstructuredURLLoader, UnstructuredHTMLLoader, UnstructuredWordDocumentLoader, UnstructuredMarkdownLoader, \
|
45 |
EverNoteLoader, UnstructuredEmailLoader, UnstructuredODTLoader, UnstructuredPowerPointLoader, \
|
46 |
+
UnstructuredEPubLoader, UnstructuredImageLoader, UnstructuredRTFLoader, ArxivLoader, UnstructuredPDFLoader
|
47 |
+
from langchain.text_splitter import RecursiveCharacterTextSplitter, Language
|
48 |
from langchain.chains.question_answering import load_qa_chain
|
49 |
from langchain.docstore.document import Document
|
50 |
+
from langchain import PromptTemplate, HuggingFaceTextGenInference
|
51 |
from langchain.vectorstores import Chroma
|
52 |
|
53 |
|
54 |
+
def get_db(sources, use_openai_embedding=False, db_type='faiss',
|
55 |
+
persist_directory="db_dir", load_db_if_exists=True,
|
56 |
+
langchain_mode='notset',
|
57 |
collection_name=None,
|
58 |
hf_embedding_model="sentence-transformers/all-MiniLM-L6-v2"):
|
59 |
if not sources:
|
60 |
return None
|
61 |
+
|
62 |
# get embedding model
|
63 |
embedding = get_embedding(use_openai_embedding, hf_embedding_model=hf_embedding_model)
|
64 |
assert collection_name is not None or langchain_mode != 'notset'
|
|
|
69 |
if db_type == 'faiss':
|
70 |
from langchain.vectorstores import FAISS
|
71 |
db = FAISS.from_documents(sources, embedding)
|
|
|
72 |
elif db_type == 'weaviate':
|
73 |
import weaviate
|
74 |
from weaviate.embedded import EmbeddedOptions
|
75 |
from langchain.vectorstores import Weaviate
|
76 |
|
77 |
+
if os.getenv('WEAVIATE_URL', None):
|
78 |
+
client = _create_local_weaviate_client()
|
79 |
+
else:
|
80 |
+
client = weaviate.Client(
|
81 |
+
embedded_options=EmbeddedOptions()
|
82 |
+
)
|
83 |
index_name = collection_name.capitalize()
|
84 |
db = Weaviate.from_documents(documents=sources, embedding=embedding, client=client, by_text=False,
|
85 |
index_name=index_name)
|
|
|
86 |
elif db_type == 'chroma':
|
87 |
assert persist_directory is not None
|
88 |
os.makedirs(persist_directory, exist_ok=True)
|
89 |
+
|
90 |
+
# see if already actually have persistent db, and deal with possible changes in embedding
|
91 |
+
db = get_existing_db(None, persist_directory, load_db_if_exists, db_type, use_openai_embedding, langchain_mode,
|
92 |
+
hf_embedding_model, verbose=False)
|
93 |
+
if db is None:
|
94 |
+
db = Chroma.from_documents(documents=sources,
|
95 |
+
embedding=embedding,
|
96 |
+
persist_directory=persist_directory,
|
97 |
+
collection_name=collection_name,
|
98 |
+
anonymized_telemetry=False)
|
99 |
+
db.persist()
|
100 |
+
clear_embedding(db)
|
101 |
+
save_embed(db, use_openai_embedding, hf_embedding_model)
|
102 |
+
else:
|
103 |
+
# then just add
|
104 |
+
db, num_new_sources, new_sources_metadata = add_to_db(db, sources, db_type=db_type,
|
105 |
+
use_openai_embedding=use_openai_embedding,
|
106 |
+
hf_embedding_model=hf_embedding_model)
|
107 |
else:
|
108 |
raise RuntimeError("No such db_type=%s" % db_type)
|
109 |
|
|
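get_db above is the entry point for turning parsed Documents into a vector store, with the Chroma branch persisting and recording the embedding used. A rough usage sketch for that path; the directory and collection names are placeholders, and it assumes the module is importable as gpt_langchain:

from langchain.docstore.document import Document
from gpt_langchain import get_db

sources = [Document(page_content="h2oGPT is an open-source LLM project.",
                    metadata=dict(source="notes/h2ogpt.txt"))]
db = get_db(sources, use_openai_embedding=False, db_type='chroma',
            persist_directory="db_dir_UserData",  # placeholder directory
            langchain_mode='UserData',
            collection_name='UserData',
            hf_embedding_model="sentence-transformers/all-MiniLM-L6-v2")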
|
126 |
|
127 |
def add_to_db(db, sources, db_type='faiss',
|
128 |
avoid_dup_by_file=False,
|
129 |
+
avoid_dup_by_content=True,
|
130 |
+
use_openai_embedding=False,
|
131 |
+
hf_embedding_model=None):
|
132 |
+
assert hf_embedding_model is not None
|
133 |
num_new_sources = len(sources)
|
134 |
if not sources:
|
135 |
return db, num_new_sources, []
|
|
|
145 |
return db, num_new_sources, []
|
146 |
db.add_documents(documents=sources)
|
147 |
elif db_type == 'chroma':
|
148 |
+
collection = get_documents(db)
|
149 |
# files we already have:
|
150 |
metadata_files = set([x['source'] for x in collection['metadatas']])
|
151 |
if avoid_dup_by_file:
|
|
|
160 |
[x['hashid'] for x in collection['metadatas'] if 'hashid' in x and x['hashid'] not in ["None", None]])
|
161 |
# avoid sources with same hash
|
162 |
sources = [x for x in sources if x.metadata.get('hashid') not in metadata_hash_ids]
|
163 |
+
num_nohash = len([x for x in sources if not x.metadata.get('hashid')])
|
164 |
+
print("Found %s new sources (%d have no hash in original source,"
|
165 |
+
" so have to reprocess for migration to sources with hash)" % (len(sources), num_nohash), flush=True)
|
166 |
# get new file names that match existing file names. delete existing files we are overridding
|
167 |
dup_metadata_files = set([x.metadata['source'] for x in sources if x.metadata['source'] in metadata_files])
|
168 |
print("Removing %s duplicate files from db because ingesting those as new documents" % len(
|
169 |
dup_metadata_files), flush=True)
|
170 |
+
client_collection = db._client.get_collection(name=db._collection.name,
|
171 |
+
embedding_function=db._collection._embedding_function)
|
172 |
for dup_file in dup_metadata_files:
|
173 |
dup_file_meta = dict(source=dup_file)
|
174 |
try:
|
|
|
180 |
return db, num_new_sources, []
|
181 |
db.add_documents(documents=sources)
|
182 |
db.persist()
|
183 |
+
clear_embedding(db)
|
184 |
+
save_embed(db, use_openai_embedding, hf_embedding_model)
|
185 |
else:
|
186 |
raise RuntimeError("No such db_type=%s" % db_type)
|
187 |
|
|
|
196 |
import weaviate
|
197 |
from weaviate.embedded import EmbeddedOptions
|
198 |
|
199 |
+
if os.getenv('WEAVIATE_URL', None):
|
200 |
+
client = _create_local_weaviate_client()
|
201 |
+
else:
|
202 |
+
client = weaviate.Client(
|
203 |
+
embedded_options=EmbeddedOptions()
|
204 |
+
)
|
205 |
+
|
206 |
index_name = collection_name.replace(' ', '_').capitalize()
|
207 |
if client.schema.exists(index_name) and not add_if_exists:
|
208 |
client.schema.delete_class(index_name)
|
|
|
239 |
if use_openai_embedding:
|
240 |
assert os.getenv("OPENAI_API_KEY") is not None, "Set ENV OPENAI_API_KEY"
|
241 |
from langchain.embeddings import OpenAIEmbeddings
|
242 |
+
embedding = OpenAIEmbeddings(disallowed_special=())
|
243 |
else:
|
244 |
# to ensure can fork without deadlock
|
245 |
from langchain.embeddings import HuggingFaceEmbeddings
|
246 |
|
247 |
device, torch_dtype, context_class = get_device_dtype()
|
248 |
model_kwargs = dict(device=device)
|
249 |
+
if 'instructor' in hf_embedding_model:
|
250 |
+
encode_kwargs = {'normalize_embeddings': True}
|
251 |
+
embedding = HuggingFaceInstructEmbeddings(model_name=hf_embedding_model,
|
252 |
+
model_kwargs=model_kwargs,
|
253 |
+
encode_kwargs=encode_kwargs)
|
254 |
+
else:
|
255 |
+
embedding = HuggingFaceEmbeddings(model_name=hf_embedding_model, model_kwargs=model_kwargs)
|
256 |
return embedding
|
257 |
|
258 |
|
|
|
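get_embedding above switches to instructor-style embeddings whenever the model name contains 'instructor', with normalized outputs. A small sketch of the two cases; the model names are illustrative:

from langchain.embeddings import HuggingFaceEmbeddings, HuggingFaceInstructEmbeddings

def pick_embedding(hf_embedding_model, device='cpu'):
    model_kwargs = dict(device=device)
    if 'instructor' in hf_embedding_model:
        # instructor models benefit from normalized embeddings
        return HuggingFaceInstructEmbeddings(model_name=hf_embedding_model,
                                             model_kwargs=model_kwargs,
                                             encode_kwargs={'normalize_embeddings': True})
    return HuggingFaceEmbeddings(model_name=hf_embedding_model, model_kwargs=model_kwargs)

embedding = pick_embedding("sentence-transformers/all-MiniLM-L6-v2")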
266 |
)["output_text"]
|
267 |
|
268 |
|
269 |
+
"""Wrapper around Huggingface text generation inference API."""
|
270 |
+
from functools import partial
|
271 |
+
from typing import Any, Dict, List, Optional, Set
|
272 |
+
|
273 |
+
from pydantic import Extra, Field, root_validator
|
274 |
+
|
275 |
+
from langchain.callbacks.manager import CallbackManagerForLLMRun
|
276 |
+
|
277 |
+
"""Wrapper around Huggingface text generation inference API."""
|
278 |
+
from functools import partial
|
279 |
+
from typing import Any, Dict, List, Optional
|
280 |
+
|
281 |
+
from pydantic import Extra, Field, root_validator
|
282 |
+
|
283 |
+
from langchain.callbacks.manager import CallbackManagerForLLMRun
|
284 |
+
from langchain.llms.base import LLM
|
285 |
+
|
286 |
+
|
287 |
+
class GradioInference(LLM):
|
288 |
+
"""
|
289 |
+
Gradio generation inference API.
|
290 |
+
"""
|
291 |
+
inference_server_url: str = ""
|
292 |
+
|
293 |
+
temperature: float = 0.8
|
294 |
+
top_p: Optional[float] = 0.95
|
295 |
+
top_k: Optional[int] = None
|
296 |
+
num_beams: Optional[int] = 1
|
297 |
+
max_new_tokens: int = 512
|
298 |
+
min_new_tokens: int = 1
|
299 |
+
early_stopping: bool = False
|
300 |
+
max_time: int = 180
|
301 |
+
repetition_penalty: Optional[float] = None
|
302 |
+
num_return_sequences: Optional[int] = 1
|
303 |
+
do_sample: bool = False
|
304 |
+
chat_client: bool = False
|
305 |
+
|
306 |
+
return_full_text: bool = True
|
307 |
+
stream: bool = False
|
308 |
+
sanitize_bot_response: bool = False
|
309 |
+
|
310 |
+
prompter: Any = None
|
311 |
+
client: Any = None
|
312 |
+
|
313 |
+
class Config:
|
314 |
+
"""Configuration for this pydantic object."""
|
315 |
+
|
316 |
+
extra = Extra.forbid
|
317 |
+
|
318 |
+
@root_validator()
|
319 |
+
def validate_environment(cls, values: Dict) -> Dict:
|
320 |
+
"""Validate that python package exists in environment."""
|
321 |
+
|
322 |
+
try:
|
323 |
+
if values['client'] is None:
|
324 |
+
import gradio_client
|
325 |
+
values["client"] = gradio_client.Client(
|
326 |
+
values["inference_server_url"]
|
327 |
+
)
|
328 |
+
except ImportError:
|
329 |
+
raise ImportError(
|
330 |
+
"Could not import gradio_client python package. "
|
331 |
+
"Please install it with `pip install gradio_client`."
|
332 |
+
)
|
333 |
+
return values
|
334 |
+
|
335 |
+
@property
|
336 |
+
def _llm_type(self) -> str:
|
337 |
+
"""Return type of llm."""
|
338 |
+
return "gradio_inference"
|
339 |
+
|
340 |
+
def _call(
|
341 |
+
self,
|
342 |
+
prompt: str,
|
343 |
+
stop: Optional[List[str]] = None,
|
344 |
+
run_manager: Optional[CallbackManagerForLLMRun] = None,
|
345 |
+
**kwargs: Any,
|
346 |
+
) -> str:
|
347 |
+
# NOTE: prompt here has no prompt_type (e.g. human: bot:) prompt injection,
|
348 |
+
# so server should get prompt_type or '', not plain
|
349 |
+
# This is good, so gradio server can also handle stopping.py conditions
|
350 |
+
# this is different than TGI server that uses prompter to inject prompt_type prompting
|
351 |
+
stream_output = self.stream
|
352 |
+
gr_client = self.client
|
353 |
+
client_langchain_mode = 'Disabled'
|
354 |
+
top_k_docs = 1
|
355 |
+
chunk = True
|
356 |
+
chunk_size = 512
|
357 |
+
client_kwargs = dict(instruction=prompt if self.chat_client else '', # only for chat=True
|
358 |
+
iinput='', # only for chat=True
|
359 |
+
context='',
|
360 |
+
# streaming output is supported, loops over and outputs each generation in streaming mode
|
361 |
+
# but leave stream_output=False for simple input/output mode
|
362 |
+
stream_output=stream_output,
|
363 |
+
prompt_type=self.prompter.prompt_type,
|
364 |
+
prompt_dict='',
|
365 |
+
|
366 |
+
temperature=self.temperature,
|
367 |
+
top_p=self.top_p,
|
368 |
+
top_k=self.top_k,
|
369 |
+
num_beams=self.num_beams,
|
370 |
+
max_new_tokens=self.max_new_tokens,
|
371 |
+
min_new_tokens=self.min_new_tokens,
|
372 |
+
early_stopping=self.early_stopping,
|
373 |
+
max_time=self.max_time,
|
374 |
+
repetition_penalty=self.repetition_penalty,
|
375 |
+
num_return_sequences=self.num_return_sequences,
|
376 |
+
do_sample=self.do_sample,
|
377 |
+
chat=self.chat_client,
|
378 |
+
|
379 |
+
instruction_nochat=prompt if not self.chat_client else '',
|
380 |
+
iinput_nochat='', # only for chat=False
|
381 |
+
langchain_mode=client_langchain_mode,
|
382 |
+
top_k_docs=top_k_docs,
|
383 |
+
chunk=chunk,
|
384 |
+
chunk_size=chunk_size,
|
385 |
+
document_choice=[DocumentChoices.All_Relevant.name],
|
386 |
+
)
|
387 |
+
api_name = '/submit_nochat_api' # NOTE: like submit_nochat but stable API for string dict passing
|
388 |
+
if not stream_output:
|
389 |
+
res = gr_client.predict(str(dict(client_kwargs)), api_name=api_name)
|
390 |
+
res_dict = ast.literal_eval(res)
|
391 |
+
text = res_dict['response']
|
392 |
+
return self.prompter.get_response(prompt + text, prompt=prompt,
|
393 |
+
sanitize_bot_response=self.sanitize_bot_response)
|
394 |
+
else:
|
395 |
+
text_callback = None
|
396 |
+
if run_manager:
|
397 |
+
text_callback = partial(
|
398 |
+
run_manager.on_llm_new_token, verbose=self.verbose
|
399 |
+
)
|
400 |
+
|
401 |
+
job = gr_client.submit(str(dict(client_kwargs)), api_name=api_name)
|
402 |
+
text0 = ''
|
403 |
+
while not job.done():
|
404 |
+
outputs_list = job.communicator.job.outputs
|
405 |
+
if outputs_list:
|
406 |
+
res = job.communicator.job.outputs[-1]
|
407 |
+
res_dict = ast.literal_eval(res)
|
408 |
+
text = res_dict['response']
|
409 |
+
text = self.prompter.get_response(prompt + text, prompt=prompt,
|
410 |
+
sanitize_bot_response=self.sanitize_bot_response)
|
411 |
+
# FIXME: derive chunk from full for now
|
412 |
+
text_chunk = text[len(text0):]
|
413 |
+
# save old
|
414 |
+
text0 = text
|
415 |
+
|
416 |
+
if text_callback:
|
417 |
+
text_callback(text_chunk)
|
418 |
+
|
419 |
+
time.sleep(0.01)
|
420 |
+
|
421 |
+
# ensure get last output to avoid race
|
422 |
+
res_all = job.outputs()
|
423 |
+
if len(res_all) > 0:
|
424 |
+
res = res_all[-1]
|
425 |
+
res_dict = ast.literal_eval(res)
|
426 |
+
text = res_dict['response']
|
427 |
+
# FIXME: derive chunk from full for now
|
428 |
+
else:
|
429 |
+
# go with old if failure
|
430 |
+
text = text0
|
431 |
+
text_chunk = text[len(text0):]
|
432 |
+
if text_callback:
|
433 |
+
text_callback(text_chunk)
|
434 |
+
return self.prompter.get_response(prompt + text, prompt=prompt,
|
435 |
+
sanitize_bot_response=self.sanitize_bot_response)
|
436 |
+
|
437 |
+
|
438 |
+
class H2OHuggingFaceTextGenInference(HuggingFaceTextGenInference):
|
439 |
+
max_new_tokens: int = 512
|
440 |
+
do_sample: bool = False
|
441 |
+
top_k: Optional[int] = None
|
442 |
+
top_p: Optional[float] = 0.95
|
443 |
+
typical_p: Optional[float] = 0.95
|
444 |
+
temperature: float = 0.8
|
445 |
+
repetition_penalty: Optional[float] = None
|
446 |
+
return_full_text: bool = False
|
447 |
+
stop_sequences: List[str] = Field(default_factory=list)
|
448 |
+
seed: Optional[int] = None
|
449 |
+
inference_server_url: str = ""
|
450 |
+
timeout: int = 300
|
451 |
+
headers: dict = None
|
452 |
+
stream: bool = False
|
453 |
+
sanitize_bot_response: bool = False
|
454 |
+
prompter: Any = None
|
455 |
+
tokenizer: Any = None
|
456 |
+
client: Any = None
|
457 |
+
|
458 |
+
@root_validator()
|
459 |
+
def validate_environment(cls, values: Dict) -> Dict:
|
460 |
+
"""Validate that python package exists in environment."""
|
461 |
+
|
462 |
+
try:
|
463 |
+
if values['client'] is None:
|
464 |
+
import text_generation
|
465 |
+
|
466 |
+
values["client"] = text_generation.Client(
|
467 |
+
values["inference_server_url"],
|
468 |
+
timeout=values["timeout"],
|
469 |
+
headers=values["headers"],
|
470 |
+
)
|
471 |
+
except ImportError:
|
472 |
+
raise ImportError(
|
473 |
+
"Could not import text_generation python package. "
|
474 |
+
"Please install it with `pip install text_generation`."
|
475 |
+
)
|
476 |
+
return values
|
477 |
+
|
478 |
+
def _call(
|
479 |
+
self,
|
480 |
+
prompt: str,
|
481 |
+
stop: Optional[List[str]] = None,
|
482 |
+
run_manager: Optional[CallbackManagerForLLMRun] = None,
|
483 |
+
**kwargs: Any,
|
484 |
+
) -> str:
|
485 |
+
if stop is None:
|
486 |
+
stop = self.stop_sequences
|
487 |
+
else:
|
488 |
+
stop += self.stop_sequences
|
489 |
+
|
490 |
+
# HF inference server needs control over input tokens
|
491 |
+
assert self.tokenizer is not None
|
492 |
+
from h2oai_pipeline import H2OTextGenerationPipeline
|
493 |
+
prompt, num_prompt_tokens = H2OTextGenerationPipeline.limit_prompt(prompt, self.tokenizer)
|
494 |
+
|
495 |
+
# NOTE: TGI server does not add prompting, so must do here
|
496 |
+
data_point = dict(context='', instruction=prompt, input='')
|
497 |
+
prompt = self.prompter.generate_prompt(data_point)
|
498 |
+
|
499 |
+
gen_server_kwargs = dict(do_sample=self.do_sample,
|
500 |
+
stop_sequences=stop,
|
501 |
+
max_new_tokens=self.max_new_tokens,
|
502 |
+
top_k=self.top_k,
|
503 |
+
top_p=self.top_p,
|
504 |
+
typical_p=self.typical_p,
|
505 |
+
temperature=self.temperature,
|
506 |
+
repetition_penalty=self.repetition_penalty,
|
507 |
+
return_full_text=self.return_full_text,
|
508 |
+
seed=self.seed,
|
509 |
+
)
|
510 |
+
gen_server_kwargs.update(kwargs)
|
511 |
+
|
512 |
+
# lower bound because client is re-used if multi-threading
|
513 |
+
self.client.timeout = max(300, self.timeout)
|
514 |
+
|
515 |
+
if not self.stream:
|
516 |
+
res = self.client.generate(
|
517 |
+
prompt,
|
518 |
+
**gen_server_kwargs,
|
519 |
+
)
|
520 |
+
if self.return_full_text:
|
521 |
+
gen_text = res.generated_text[len(prompt):]
|
522 |
+
else:
|
523 |
+
gen_text = res.generated_text
|
524 |
+
# remove stop sequences from the end of the generated text
|
525 |
+
for stop_seq in stop:
|
526 |
+
if stop_seq in gen_text:
|
527 |
+
gen_text = gen_text[:gen_text.index(stop_seq)]
|
528 |
+
text = prompt + gen_text
|
529 |
+
text = self.prompter.get_response(text, prompt=prompt,
|
530 |
+
sanitize_bot_response=self.sanitize_bot_response)
|
531 |
+
else:
|
532 |
+
text_callback = None
|
533 |
+
if run_manager:
|
534 |
+
text_callback = partial(
|
535 |
+
run_manager.on_llm_new_token, verbose=self.verbose
|
536 |
+
)
|
537 |
+
# parent handler of streamer expects to see prompt first else output="" and lose if prompt=None in prompter
|
538 |
+
if text_callback:
|
539 |
+
text_callback(prompt)
|
540 |
+
text = ""
|
541 |
+
# Note: Streaming ignores return_full_text=True
|
542 |
+
for response in self.client.generate_stream(prompt, **gen_server_kwargs):
|
543 |
+
text_chunk = response.token.text
|
544 |
+
text += text_chunk
|
545 |
+
text = self.prompter.get_response(prompt + text, prompt=prompt,
|
546 |
+
sanitize_bot_response=self.sanitize_bot_response)
|
547 |
+
# stream part
|
548 |
+
is_stop = False
|
549 |
+
for stop_seq in stop:
|
550 |
+
if stop_seq in response.token.text:
|
551 |
+
is_stop = True
|
552 |
+
break
|
553 |
+
if is_stop:
|
554 |
+
break
|
555 |
+
if not response.token.special:
|
556 |
+
if text_callback:
|
557 |
+
text_callback(response.token.text)
|
558 |
+
return text
|
559 |
+
|
560 |
+
|
561 |
+
from langchain.chat_models import ChatOpenAI
|
562 |
+
|
563 |
+
|
564 |
+
class H2OChatOpenAI(ChatOpenAI):
|
565 |
+
@classmethod
|
566 |
+
def all_required_field_names(cls) -> Set:
|
567 |
+
all_required_field_names = super(ChatOpenAI, cls).all_required_field_names()
|
568 |
+
all_required_field_names.update({'top_p', 'frequency_penalty', 'presence_penalty'})
|
569 |
+
return all_required_field_names
|
570 |
+
|
571 |
+
|
572 |
+
def get_llm(use_openai_model=False,
|
573 |
+
model_name=None,
|
574 |
+
model=None,
|
575 |
+
tokenizer=None,
|
576 |
+
inference_server=None,
|
577 |
+
stream_output=False,
|
578 |
+
do_sample=False,
|
579 |
temperature=0.1,
|
|
|
580 |
top_k=40,
|
581 |
top_p=0.7,
|
582 |
+
num_beams=1,
|
583 |
+
max_new_tokens=256,
|
584 |
+
min_new_tokens=1,
|
585 |
+
early_stopping=False,
|
586 |
+
max_time=180,
|
587 |
+
repetition_penalty=1.0,
|
588 |
+
num_return_sequences=1,
|
589 |
prompt_type=None,
|
590 |
+
prompt_dict=None,
|
591 |
prompter=None,
|
592 |
+
sanitize_bot_response=False,
|
593 |
verbose=False,
|
594 |
):
|
595 |
+
if use_openai_model or inference_server in ['openai', 'openai_chat']:
|
596 |
+
if use_openai_model and model_name is None:
|
597 |
+
model_name = "gpt-3.5-turbo"
|
598 |
+
if inference_server == 'openai':
|
599 |
+
from langchain.llms import OpenAI
|
600 |
+
cls = OpenAI
|
601 |
+
else:
|
602 |
+
cls = H2OChatOpenAI
|
603 |
+
callbacks = [StreamingGradioCallbackHandler()]
|
604 |
+
llm = cls(model_name=model_name,
|
605 |
+
temperature=temperature if do_sample else 0,
|
606 |
+
# FIXME: Need to count tokens and reduce max_new_tokens to fit like in generate.py
|
607 |
+
max_tokens=max_new_tokens,
|
608 |
+
top_p=top_p if do_sample else 1,
|
609 |
+
frequency_penalty=0,
|
610 |
+
presence_penalty=1.07 - repetition_penalty + 0.6, # so good default
|
611 |
+
callbacks=callbacks if stream_output else None,
|
612 |
+
)
|
613 |
+
streamer = callbacks[0] if stream_output else None
|
614 |
+
if inference_server in ['openai', 'openai_chat']:
|
615 |
+
prompt_type = inference_server
|
616 |
+
else:
|
617 |
+
prompt_type = prompt_type or 'plain'
|
618 |
+
elif inference_server:
|
619 |
+
assert inference_server.startswith(
|
620 |
+
'http'), "Malformed inference_server=%s. Did you add http:// in front?" % inference_server
|
621 |
+
|
622 |
+
from gradio_utils.grclient import GradioClient
|
623 |
+
from text_generation import Client as HFClient
|
624 |
+
if isinstance(model, GradioClient):
|
625 |
+
gr_client = model
|
626 |
+
hf_client = None
|
627 |
+
else:
|
628 |
+
gr_client = None
|
629 |
+
hf_client = model
|
630 |
+
assert isinstance(hf_client, HFClient)
|
631 |
+
|
632 |
+
inference_server, headers = get_hf_server(inference_server)
|
633 |
+
|
634 |
+
# quick sanity check to avoid long timeouts, just see if can reach server
|
635 |
+
requests.get(inference_server, timeout=int(os.getenv('REQUEST_TIMEOUT_FAST', '10')))
|
636 |
+
|
637 |
+
callbacks = [StreamingGradioCallbackHandler()]
|
638 |
+
assert prompter is not None
|
639 |
+
stop_sequences = list(set(prompter.terminate_response + [prompter.PreResponse]))
|
640 |
+
|
641 |
+
if gr_client:
|
642 |
+
chat_client = False
|
643 |
+
llm = GradioInference(
|
644 |
+
inference_server_url=inference_server,
|
645 |
+
return_full_text=True,
|
646 |
+
|
647 |
+
temperature=temperature,
|
648 |
+
top_p=top_p,
|
649 |
+
top_k=top_k,
|
650 |
+
num_beams=num_beams,
|
651 |
+
max_new_tokens=max_new_tokens,
|
652 |
+
min_new_tokens=min_new_tokens,
|
653 |
+
early_stopping=early_stopping,
|
654 |
+
max_time=max_time,
|
655 |
+
repetition_penalty=repetition_penalty,
|
656 |
+
num_return_sequences=num_return_sequences,
|
657 |
+
do_sample=do_sample,
|
658 |
+
chat_client=chat_client,
|
659 |
+
|
660 |
+
callbacks=callbacks if stream_output else None,
|
661 |
+
stream=stream_output,
|
662 |
+
prompter=prompter,
|
663 |
+
client=gr_client,
|
664 |
+
sanitize_bot_response=sanitize_bot_response,
|
665 |
+
)
|
666 |
+
elif hf_client:
|
667 |
+
llm = H2OHuggingFaceTextGenInference(
|
668 |
+
inference_server_url=inference_server,
|
669 |
+
do_sample=do_sample,
|
670 |
+
max_new_tokens=max_new_tokens,
|
671 |
+
repetition_penalty=repetition_penalty,
|
672 |
+
return_full_text=True,
|
673 |
+
seed=SEED,
|
674 |
+
|
675 |
+
stop_sequences=stop_sequences,
|
676 |
+
temperature=temperature,
|
677 |
+
top_k=top_k,
|
678 |
+
top_p=top_p,
|
679 |
+
# typical_p=top_p,
|
680 |
+
callbacks=callbacks if stream_output else None,
|
681 |
+
stream=stream_output,
|
682 |
+
prompter=prompter,
|
683 |
+
tokenizer=tokenizer,
|
684 |
+
client=hf_client,
|
685 |
+
timeout=max_time,
|
686 |
+
sanitize_bot_response=sanitize_bot_response,
|
687 |
+
)
|
688 |
+
else:
|
689 |
+
raise RuntimeError("No defined client")
|
690 |
+
streamer = callbacks[0] if stream_output else None
|
691 |
elif model_name in non_hf_types:
|
692 |
+
if model_name == 'llama':
|
693 |
+
callbacks = [StreamingGradioCallbackHandler()]
|
694 |
+
streamer = callbacks[0] if stream_output else None
|
695 |
+
else:
|
696 |
+
# stream_output = False
|
697 |
+
# doesn't stream properly as generator, but at least
|
698 |
+
callbacks = [streaming_stdout.StreamingStdOutCallbackHandler()]
|
699 |
+
streamer = None
|
700 |
+
if prompter:
|
701 |
+
prompt_type = prompter.prompt_type
|
702 |
+
else:
|
703 |
+
prompter = Prompter(prompt_type, prompt_dict, debug=False, chat=False, stream_output=stream_output)
|
704 |
+
pass # assume inputted prompt_type is correct
|
705 |
from gpt4all_llm import get_llm_gpt4all
|
706 |
llm = get_llm_gpt4all(model_name, model=model, max_new_tokens=max_new_tokens,
|
707 |
temperature=temperature,
|
708 |
repetition_penalty=repetition_penalty,
|
709 |
top_k=top_k,
|
710 |
top_p=top_p,
|
711 |
+
callbacks=callbacks,
|
712 |
verbose=verbose,
|
713 |
+
streaming=stream_output,
|
714 |
+
prompter=prompter,
|
715 |
)
|
|
|
|
|
716 |
else:
|
|
|
|
|
717 |
if model is None:
|
718 |
# only used if didn't pass model in
|
|
|
719 |
assert tokenizer is None
|
720 |
prompt_type = 'human_bot'
|
721 |
+
if model_name is None:
|
722 |
+
model_name = 'h2oai/h2ogpt-oasst1-512-12b'
|
723 |
+
# model_name = 'h2oai/h2ogpt-oig-oasst1-512-6_9b'
|
724 |
+
# model_name = 'h2oai/h2ogpt-oasst1-512-20b'
|
725 |
+
inference_server = ''
|
726 |
+
model, tokenizer, device = get_model(load_8bit=True, base_model=model_name,
|
727 |
+
inference_server=inference_server, gpu_id=0)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
728 |
|
729 |
max_max_tokens = tokenizer.model_max_length
|
730 |
+
gen_kwargs = dict(do_sample=do_sample,
|
731 |
+
temperature=temperature,
|
732 |
+
top_k=top_k,
|
733 |
+
top_p=top_p,
|
734 |
+
num_beams=num_beams,
|
735 |
+
max_new_tokens=max_new_tokens,
|
736 |
+
min_new_tokens=min_new_tokens,
|
737 |
+
early_stopping=early_stopping,
|
738 |
+
max_time=max_time,
|
739 |
+
repetition_penalty=repetition_penalty,
|
740 |
+
num_return_sequences=num_return_sequences,
|
741 |
return_full_text=True,
|
742 |
+
handle_long_generation=None)
|
743 |
+
assert len(set(gen_hyper).difference(gen_kwargs.keys())) == 0
|
744 |
|
745 |
if stream_output:
|
746 |
skip_prompt = False
|
|
|
755 |
pipe = H2OTextGenerationPipeline(model=model, use_prompter=True,
|
756 |
prompter=prompter,
|
757 |
prompt_type=prompt_type,
|
758 |
+
prompt_dict=prompt_dict,
|
759 |
+
sanitize_bot_response=sanitize_bot_response,
|
760 |
chat=False, stream_output=stream_output,
|
761 |
tokenizer=tokenizer,
|
762 |
+
# leave some room for 1 paragraph, even if min_new_tokens=0
|
763 |
+
max_input_tokens=max_max_tokens - max(min_new_tokens, 256),
|
764 |
**gen_kwargs)
|
765 |
# pipe.task = "text-generation"
|
766 |
# below makes it listen only to our prompt removal,
|
|
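The GradioInference class added above drives another h2oGPT server through the stable string-dict endpoint /submit_nochat_api. A bare-bones sketch of that client call; the server URL is a placeholder and only a few of the fields are shown here, whereas the full set the server expects is the client_kwargs dict in GradioInference above:

import ast
from gradio_client import Client

client = Client("http://localhost:7860")  # placeholder h2oGPT server
client_kwargs = dict(instruction_nochat="Who are you?", iinput_nochat='', context='',
                     stream_output=False, prompt_type='human_bot', langchain_mode='Disabled')
res = client.predict(str(dict(client_kwargs)), api_name='/submit_nochat_api')
print(ast.literal_eval(res)['response'])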
|
805 |
data = json.load(open(filename, "rt"))
|
806 |
page_content = list(data["query"]["pages"].values())[0]["extract"]
|
807 |
if take_head is not None and text_limit is not None:
|
808 |
+
page_content = page_content[:text_limit] if take_head else page_content[-text_limit:]
|
809 |
title_url = str(title).replace(' ', '_')
|
810 |
return Document(
|
811 |
page_content=page_content,
|
|
|
927 |
except (pkg_resources.DistributionNotFound, AssertionError):
|
928 |
have_pymupdf = False
|
929 |
|
930 |
+
try:
|
931 |
+
assert pkg_resources.get_distribution('selenium') is not None
|
932 |
+
have_selenium = True
|
933 |
+
except (pkg_resources.DistributionNotFound, AssertionError):
|
934 |
+
have_selenium = False
|
935 |
+
|
936 |
+
try:
|
937 |
+
assert pkg_resources.get_distribution('playwright') is not None
|
938 |
+
have_playwright = True
|
939 |
+
except (pkg_resources.DistributionNotFound, AssertionError):
|
940 |
+
have_playwright = False
|
941 |
+
|
942 |
+
# disable, hangs too often
|
943 |
+
have_playwright = False
|
944 |
+
|
945 |
image_types = ["png", "jpg", "jpeg"]
|
946 |
non_image_types = ["pdf", "txt", "csv", "toml", "py", "rst", "rtf",
|
947 |
"md", "html",
|
|
|
959 |
def add_meta(docs1, file):
|
960 |
file_extension = pathlib.Path(file).suffix
|
961 |
hashid = hash_file(file)
|
962 |
+
if not isinstance(docs1, (list, tuple, types.GeneratorType)):
|
963 |
docs1 = [docs1]
|
964 |
[x.metadata.update(dict(input_type=file_extension, date=str(datetime.now), hashid=hashid)) for x in docs1]
|
965 |
|
966 |
|
967 |
+
def file_to_doc(file, base_path=None, verbose=False, fail_any_exception=False,
|
968 |
+
chunk=True, chunk_size=512,
|
969 |
is_url=False, is_txt=False,
|
970 |
enable_captions=True,
|
971 |
captions_model=None,
|
|
|
1001 |
else:
|
1002 |
docs1 = []
|
1003 |
else:
|
1004 |
+
if not (file.startswith("http://") or file.startswith("file://") or file.startswith("https://")):
|
1005 |
+
file = 'http://' + file
|
1006 |
docs1 = UnstructuredURLLoader(urls=[file]).load()
|
1007 |
+
if len(docs1) == 0 and have_playwright:
|
1008 |
+
# then something went wrong, try another loader:
|
1009 |
+
from langchain.document_loaders import PlaywrightURLLoader
|
1010 |
+
docs1 = PlaywrightURLLoader(urls=[file]).load()
|
1011 |
+
if len(docs1) == 0 and have_selenium:
|
1012 |
+
# then something went wrong, try another loader:
|
1013 |
+
# but requires Chrome binary, else get: selenium.common.exceptions.WebDriverException: Message: unknown error: cannot find Chrome binary
|
1014 |
+
from langchain.document_loaders import SeleniumURLLoader
|
1015 |
+
from selenium.common.exceptions import WebDriverException
|
1016 |
+
try:
|
1017 |
+
docs1 = SeleniumURLLoader(urls=[file]).load()
|
1018 |
+
except WebDriverException as e:
|
1019 |
+
print("No web driver: %s" % str(e), flush=True)
|
1020 |
[x.metadata.update(dict(input_type='url', date=str(datetime.now))) for x in docs1]
|
1021 |
+
docs1 = clean_doc(docs1)
|
1022 |
+
doc1 = chunk_sources(docs1, chunk=chunk, chunk_size=chunk_size)
|
1023 |
elif is_txt:
|
1024 |
base_path = "user_paste"
|
1025 |
source_file = os.path.join(base_path, "_%s" % str(uuid.uuid4())[:10])
|
|
|
1028 |
f.write(file)
|
1029 |
metadata = dict(source=source_file, date=str(datetime.now()), input_type='pasted txt')
|
1030 |
doc1 = Document(page_content=file, metadata=metadata)
|
1031 |
+
doc1 = clean_doc(doc1)
|
1032 |
elif file.lower().endswith('.html') or file.lower().endswith('.mhtml'):
|
1033 |
docs1 = UnstructuredHTMLLoader(file_path=file).load()
|
1034 |
add_meta(docs1, file)
|
1035 |
+
docs1 = clean_doc(docs1)
|
1036 |
+
doc1 = chunk_sources(docs1, chunk=chunk, chunk_size=chunk_size, language=Language.HTML)
|
1037 |
elif (file.lower().endswith('.docx') or file.lower().endswith('.doc')) and have_libreoffice:
|
1038 |
docs1 = UnstructuredWordDocumentLoader(file_path=file).load()
|
1039 |
add_meta(docs1, file)
|
1040 |
+
doc1 = chunk_sources(docs1, chunk=chunk, chunk_size=chunk_size)
|
1041 |
elif file.lower().endswith('.odt'):
|
1042 |
docs1 = UnstructuredODTLoader(file_path=file).load()
|
1043 |
add_meta(docs1, file)
|
1044 |
+
doc1 = chunk_sources(docs1, chunk=chunk, chunk_size=chunk_size)
|
1045 |
elif file.lower().endswith('pptx') or file.lower().endswith('ppt'):
|
1046 |
docs1 = UnstructuredPowerPointLoader(file_path=file).load()
|
1047 |
add_meta(docs1, file)
|
1048 |
+
docs1 = clean_doc(docs1)
|
1049 |
+
doc1 = chunk_sources(docs1, chunk=chunk, chunk_size=chunk_size)
|
1050 |
elif file.lower().endswith('.txt'):
|
1051 |
# use UnstructuredFileLoader ?
|
1052 |
docs1 = TextLoader(file, encoding="utf8", autodetect_encoding=True).load()
|
1053 |
# makes just one, but big one
|
1054 |
+
doc1 = chunk_sources(docs1, chunk=chunk, chunk_size=chunk_size)
|
1055 |
+
doc1 = clean_doc(doc1)
|
1056 |
add_meta(doc1, file)
|
1057 |
elif file.lower().endswith('.rtf'):
|
1058 |
docs1 = UnstructuredRTFLoader(file).load()
|
1059 |
add_meta(docs1, file)
|
1060 |
+
doc1 = chunk_sources(docs1, chunk=chunk, chunk_size=chunk_size)
|
1061 |
elif file.lower().endswith('.md'):
|
1062 |
docs1 = UnstructuredMarkdownLoader(file).load()
|
1063 |
add_meta(docs1, file)
|
1064 |
+
docs1 = clean_doc(docs1)
|
1065 |
+
doc1 = chunk_sources(docs1, chunk=chunk, chunk_size=chunk_size, language=Language.MARKDOWN)
|
1066 |
elif file.lower().endswith('.enex'):
|
1067 |
docs1 = EverNoteLoader(file).load()
|
1068 |
add_meta(doc1, file)
|
1069 |
+
doc1 = chunk_sources(docs1, chunk=chunk, chunk_size=chunk_size)
|
1070 |
elif file.lower().endswith('.epub'):
|
1071 |
docs1 = UnstructuredEPubLoader(file).load()
|
1072 |
add_meta(docs1, file)
|
1073 |
+
doc1 = chunk_sources(docs1, chunk=chunk, chunk_size=chunk_size)
|
1074 |
elif file.lower().endswith('.jpeg') or file.lower().endswith('.jpg') or file.lower().endswith('.png'):
|
1075 |
docs1 = []
|
1076 |
if have_tesseract and enable_ocr:
|
|
|
1100 |
doci.metadata['source'] = doci.metadata['image_path']
|
1101 |
doci.metadata['hash'] = hash_file(doci.metadata['source'])
|
1102 |
if docs1:
|
1103 |
+
doc1 = chunk_sources(docs1, chunk=chunk, chunk_size=chunk_size)
|
1104 |
elif file.lower().endswith('.msg'):
|
1105 |
raise RuntimeError("Not supported, GPL3 license")
|
1106 |
# docs1 = OutlookMessageLoader(file).load()
|
|
|
1109 |
try:
|
1110 |
docs1 = UnstructuredEmailLoader(file).load()
|
1111 |
add_meta(docs1, file)
|
1112 |
+
doc1 = chunk_sources(docs1, chunk=chunk, chunk_size=chunk_size)
|
1113 |
except ValueError as e:
|
1114 |
if 'text/html content not found in email' in str(e):
|
1115 |
# e.g. plain/text dict key exists, but not
|
1116 |
# doc1 = TextLoader(file, encoding="utf8").load()
|
1117 |
docs1 = UnstructuredEmailLoader(file, content_source="text/plain").load()
|
1118 |
add_meta(docs1, file)
|
1119 |
+
doc1 = chunk_sources(docs1, chunk=chunk, chunk_size=chunk_size)
|
1120 |
else:
|
1121 |
raise
|
1122 |
# elif file.lower().endswith('.gcsdir'):
|
|
|
1127 |
with open(file, "r") as f:
|
1128 |
doc1 = Document(page_content=f.read(), metadata={"source": file})
|
1129 |
add_meta(doc1, file)
|
1130 |
+
doc1 = chunk_sources(doc1, chunk=chunk, chunk_size=chunk_size, language=Language.RST)
|
1131 |
elif file.lower().endswith('.pdf'):
|
1132 |
env_gpt4all_file = ".env_gpt4all"
|
1133 |
from dotenv import dotenv_values
|
|
|
1136 |
if have_pymupdf and pdf_class_name == 'PyMuPDFParser':
|
1137 |
# GPL, only use if installed
|
1138 |
from langchain.document_loaders import PyMuPDFLoader
|
1139 |
+
# load() still chunks by pages, but every page has title at start to help
|
1140 |
+
doc1 = PyMuPDFLoader(file).load()
|
1141 |
+
doc1 = clean_doc(doc1)
|
1142 |
+
elif pdf_class_name == 'UnstructuredPDFLoader':
|
1143 |
+
doc1 = UnstructuredPDFLoader(file).load()
|
1144 |
+
# seems to not need cleaning in most cases
|
1145 |
else:
|
1146 |
# open-source fallback
|
1147 |
+
# load() still chunks by pages, but every page has title at start to help
|
1148 |
+
doc1 = PyPDFLoader(file).load()
|
1149 |
+
doc1 = clean_doc(doc1)
|
1150 |
# Some PDFs return nothing or junk from PDFMinerLoader
|
1151 |
+
doc1 = chunk_sources(doc1, chunk=chunk, chunk_size=chunk_size)
|
1152 |
add_meta(doc1, file)
|
1153 |
elif file.lower().endswith('.csv'):
|
1154 |
doc1 = CSVLoader(file).load()
|
|
|
1156 |
elif file.lower().endswith('.py'):
|
1157 |
doc1 = PythonLoader(file).load()
|
1158 |
add_meta(doc1, file)
|
1159 |
+
doc1 = chunk_sources(doc1, chunk=chunk, chunk_size=chunk_size, language=Language.PYTHON)
|
1160 |
elif file.lower().endswith('.toml'):
|
1161 |
doc1 = TomlLoader(file).load()
|
1162 |
add_meta(doc1, file)
|
|
|
1164 |
with open(file, "r") as f:
|
1165 |
docs1 = UnstructuredURLLoader(urls=f.readlines()).load()
|
1166 |
add_meta(docs1, file)
|
1167 |
+
doc1 = chunk_sources(docs1, chunk=chunk, chunk_size=chunk_size)
|
1168 |
elif file.lower().endswith('.zip'):
|
1169 |
with zipfile.ZipFile(file, 'r') as zip_ref:
|
1170 |
# don't put into temporary path, since want to keep references to docs inside zip
|
|
|
1179 |
# if list of length one, don't trust and chunk it
|
1180 |
if not isinstance(doc1, list):
|
1181 |
if chunk:
|
1182 |
+
docs = chunk_sources([doc1], chunk=chunk, chunk_size=chunk_size)
|
1183 |
else:
|
1184 |
docs = [doc1]
|
1185 |
elif isinstance(doc1, list) and len(doc1) == 1:
|
1186 |
if chunk:
|
1187 |
+
docs = chunk_sources(doc1, chunk=chunk, chunk_size=chunk_size)
|
1188 |
else:
|
1189 |
docs = doc1
|
1190 |
else:
|
|
|
1194 |
return docs
|
1195 |
|
1196 |
|
1197 |
+
def path_to_doc1(file, verbose=False, fail_any_exception=False, return_file=True,
|
1198 |
+
chunk=True, chunk_size=512,
|
1199 |
is_url=False, is_txt=False,
|
1200 |
enable_captions=True,
|
1201 |
captions_model=None,
|
|
|
1247 |
existing_files=[],
|
1248 |
existing_hash_ids={},
|
1249 |
):
|
1250 |
+
# path_or_paths could be str, list, tuple, generator
|
1251 |
globs_image_types = []
|
1252 |
globs_non_image_types = []
|
1253 |
if not path_or_paths and not url and not text:
|
1254 |
return []
|
1255 |
elif url:
|
1256 |
+
globs_non_image_types = url if isinstance(url, (list, tuple, types.GeneratorType)) else [url]
|
1257 |
elif text:
|
1258 |
+
globs_non_image_types = text if isinstance(text, (list, tuple, types.GeneratorType)) else [text]
|
1259 |
+
elif isinstance(path_or_paths, str) and os.path.isdir(path_or_paths):
|
1260 |
# single path, only consume allowed files
|
1261 |
path = path_or_paths
|
1262 |
# Below globs should match patterns in file_to_doc()
|
|
|
1265 |
[globs_non_image_types.extend(glob.glob(os.path.join(path, "./**/*.%s" % ftype), recursive=True))
|
1266 |
for ftype in non_image_types]
|
1267 |
else:
|
1268 |
+
if isinstance(path_or_paths, str) and (os.path.isfile(path_or_paths) or os.path.isdir(path_or_paths)):
|
1269 |
+
path_or_paths = [path_or_paths]
|
1270 |
# list/tuple of files (consume what can, and exception those that selected but cannot consume so user knows)
|
1271 |
+
assert isinstance(path_or_paths, (list, tuple, types.GeneratorType)), "Wrong type for path_or_paths: %s" % type(
|
1272 |
+
path_or_paths)
|
1273 |
# reform out of allowed types
|
1274 |
globs_image_types.extend(flatten_list([[x for x in path_or_paths if x.endswith(y)] for y in image_types]))
|
1275 |
# could do below:
|
|
|
1373 |
|
1374 |
if db_dir_exists and user_path is None:
|
1375 |
print("Prep: persist_directory=%s exists, using" % persist_directory, flush=True)
|
1376 |
+
db = get_existing_db(None, persist_directory, load_db_if_exists, db_type, use_openai_embedding, langchain_mode,
|
1377 |
hf_embedding_model)
|
1378 |
else:
|
1379 |
if db_dir_exists and user_path is not None:
|
1380 |
print("Prep: persist_directory=%s exists, user_path=%s passed, adding any changed or new documents" % (
|
1381 |
+
persist_directory, user_path), flush=True)
|
1382 |
elif not db_dir_exists:
|
1383 |
print("Prep: persist_directory=%s does not exist, regenerating" % persist_directory, flush=True)
|
1384 |
db = None
|
|
|
1424 |
posthog.Consumer = FakeConsumer
|
1425 |
|
1426 |
|
1427 |
+
def check_update_chroma_embedding(db, use_openai_embedding, hf_embedding_model, langchain_mode):
|
1428 |
+
changed_db = False
|
1429 |
+
if load_embed(db) != (use_openai_embedding, hf_embedding_model):
|
1430 |
+
print("Detected new embedding, updating db: %s" % langchain_mode, flush=True)
|
1431 |
+
# handle embedding changes
|
1432 |
+
db_get = get_documents(db)
|
1433 |
+
sources = [Document(page_content=result[0], metadata=result[1] or {})
|
1434 |
+
for result in zip(db_get['documents'], db_get['metadatas'])]
|
1435 |
+
# delete index, has to be redone
|
1436 |
+
persist_directory = db._persist_directory
|
1437 |
+
shutil.move(persist_directory, persist_directory + "_" + str(uuid.uuid4()) + ".bak")
|
1438 |
+
db_type = 'chroma'
|
1439 |
+
load_db_if_exists = False
|
1440 |
+
db = get_db(sources, use_openai_embedding=use_openai_embedding, db_type=db_type,
|
1441 |
+
persist_directory=persist_directory, load_db_if_exists=load_db_if_exists,
|
1442 |
+
langchain_mode=langchain_mode,
|
1443 |
+
collection_name=None,
|
1444 |
+
hf_embedding_model=hf_embedding_model)
|
1445 |
+
if False:
|
1446 |
+
# below doesn't work if db already in memory, so have to switch to new db as above
|
1447 |
+
# upsert does new embedding, but if index already in memory, complains about size mismatch etc.
|
1448 |
+
client_collection = db._client.get_collection(name=db._collection.name,
|
1449 |
+
embedding_function=db._collection._embedding_function)
|
1450 |
+
client_collection.upsert(ids=db_get['ids'], metadatas=db_get['metadatas'], documents=db_get['documents'])
|
1451 |
+
changed_db = True
|
1452 |
+
print("Done updating db for new embedding: %s" % langchain_mode, flush=True)
|
1453 |
+
|
1454 |
+
return db, changed_db
|
1455 |
+
|
1456 |
+
|
1457 |
+
def get_existing_db(db, persist_directory, load_db_if_exists, db_type, use_openai_embedding, langchain_mode,
|
1458 |
+
hf_embedding_model, verbose=False, check_embedding=True):
|
1459 |
if load_db_if_exists and db_type == 'chroma' and os.path.isdir(persist_directory) and os.path.isdir(
|
1460 |
os.path.join(persist_directory, 'index')):
|
1461 |
+
if db is None:
|
1462 |
+
if verbose:
|
1463 |
+
print("DO Loading db: %s" % langchain_mode, flush=True)
|
1464 |
+
embedding = get_embedding(use_openai_embedding, hf_embedding_model=hf_embedding_model)
|
1465 |
+
from chromadb.config import Settings
|
1466 |
+
client_settings = Settings(anonymized_telemetry=False,
|
1467 |
+
chroma_db_impl="duckdb+parquet",
|
1468 |
+
persist_directory=persist_directory)
|
1469 |
+
db = Chroma(persist_directory=persist_directory, embedding_function=embedding,
|
1470 |
+
collection_name=langchain_mode.replace(' ', '_'),
|
1471 |
+
client_settings=client_settings)
|
1472 |
+
if verbose:
|
1473 |
+
print("DONE Loading db: %s" % langchain_mode, flush=True)
|
1474 |
+
else:
|
1475 |
+
if verbose:
|
1476 |
+
print("USING already-loaded db: %s" % langchain_mode, flush=True)
|
1477 |
+
if check_embedding:
|
1478 |
+
db_trial, changed_db = check_update_chroma_embedding(db, use_openai_embedding, hf_embedding_model,
|
1479 |
+
langchain_mode)
|
1480 |
+
if changed_db:
|
1481 |
+
db = db_trial
|
1482 |
+
# only call persist if really changed db, else takes too long for large db
|
1483 |
+
if db is not None:
|
1484 |
+
db.persist()
|
1485 |
+
clear_embedding(db)
|
1486 |
+
save_embed(db, use_openai_embedding, hf_embedding_model)
|
1487 |
return db
|
1488 |
return None
|
1489 |
|
1490 |
|
1491 |
+
def clear_embedding(db):
|
1492 |
+
if db is None:
|
1493 |
+
return
|
1494 |
+
# don't keep on GPU, wastes memory; push back onto CPU and only put back on GPU when embedding again
|
1495 |
+
db._embedding_function.client.cpu()
|
1496 |
+
clear_torch_cache()
|
1497 |
+
|
1498 |
+
|
1499 |
def make_db(**langchain_kwargs):
|
1500 |
func_names = list(inspect.signature(_make_db).parameters)
|
1501 |
missing_kwargs = [x for x in func_names if x not in langchain_kwargs]
|
|
|
1511 |
return _make_db(**langchain_kwargs)
|
1512 |
|
1513 |
|
1514 |
+
def save_embed(db, use_openai_embedding, hf_embedding_model):
|
1515 |
+
if db is not None:
|
1516 |
+
embed_info_file = os.path.join(db._persist_directory, 'embed_info')
|
1517 |
+
with open(embed_info_file, 'wb') as f:
|
1518 |
+
pickle.dump((use_openai_embedding, hf_embedding_model), f)
|
1519 |
+
return use_openai_embedding, hf_embedding_model
|
1520 |
+
|
1521 |
+
|
1522 |
+
def load_embed(db):
|
1523 |
+
embed_info_file = os.path.join(db._persist_directory, 'embed_info')
|
1524 |
+
if os.path.isfile(embed_info_file):
|
1525 |
+
with open(embed_info_file, 'rb') as f:
|
1526 |
+
use_openai_embedding, hf_embedding_model = pickle.load(f)
|
1527 |
+
else:
|
1528 |
+
# migration, assume defaults
|
1529 |
+
use_openai_embedding, hf_embedding_model = False, "sentence-transformers/all-MiniLM-L6-v2"
|
1530 |
+
return use_openai_embedding, hf_embedding_model
|
1531 |
+
|
1532 |
+
|
1533 |
+
def get_persist_directory(langchain_mode):
|
1534 |
+
return 'db_dir_%s' % langchain_mode # single place, no special names for each case
|
1535 |
+
|
1536 |
+
|
1537 |
def _make_db(use_openai_embedding=False,
|
1538 |
hf_embedding_model="sentence-transformers/all-MiniLM-L6-v2",
|
1539 |
+
first_para=False, text_limit=None,
|
1540 |
+
chunk=True, chunk_size=512,
|
1541 |
langchain_mode=None,
|
1542 |
user_path=None,
|
1543 |
db_type='faiss',
|
|
|
1545 |
db=None,
|
1546 |
n_jobs=-1,
|
1547 |
verbose=False):
|
1548 |
+
persist_directory = get_persist_directory(langchain_mode)
|
1549 |
+
# see if can get persistent chroma db
|
1550 |
+
db_trial = get_existing_db(db, persist_directory, load_db_if_exists, db_type, use_openai_embedding, langchain_mode,
|
1551 |
+
hf_embedding_model, verbose=verbose)
|
1552 |
+
if db_trial is not None:
|
1553 |
+
db = db_trial
|
1554 |
+
|
|
|
|
1555 |
sources = []
|
1556 |
if not db and langchain_mode not in ['MyData'] or \
|
1557 |
user_path is not None and \
|
|
|
1576 |
sources1 = get_all_documents(small_test=small_test, n_jobs=os.cpu_count() // 2)
|
1577 |
print("Got new wiki", flush=True)
|
1578 |
if chunk:
|
1579 |
+
sources1 = chunk_sources(sources1, chunk=chunk, chunk_size=chunk_size)
|
1580 |
print("Chunked new wiki", flush=True)
|
1581 |
sources.extend(sources1)
|
1582 |
if langchain_mode in ['wiki', 'All', "'All'"]:
|
1583 |
sources1 = get_wiki_sources(first_para=first_para, text_limit=text_limit)
|
1584 |
if chunk:
|
1585 |
+
sources1 = chunk_sources(sources1, chunk=chunk, chunk_size=chunk_size)
|
1586 |
sources.extend(sources1)
|
1587 |
if langchain_mode in ['github h2oGPT', 'All', "'All'"]:
|
1588 |
# sources = get_github_docs("dagster-io", "dagster")
|
1589 |
sources1 = get_github_docs("h2oai", "h2ogpt")
|
1590 |
# FIXME: always chunk for now
|
1591 |
+
sources1 = chunk_sources(sources1, chunk=chunk, chunk_size=chunk_size)
|
1592 |
sources.extend(sources1)
|
1593 |
if langchain_mode in ['DriverlessAI docs', 'All', "'All'"]:
|
1594 |
sources1 = get_dai_docs(from_hf=True)
|
1595 |
if chunk and False: # FIXME: DAI docs are already chunked well, should only chunk more if over limit
|
1596 |
+
sources1 = chunk_sources(sources1, chunk=chunk, chunk_size=chunk_size)
|
1597 |
sources.extend(sources1)
|
1598 |
if langchain_mode in ['All', 'UserData']:
|
1599 |
if user_path:
|
|
|
1607 |
existing_files = []
|
1608 |
existing_hash_ids = []
|
1609 |
# chunk internally for speed over multiple docs
|
1610 |
+
# FIXME: If first had old Hash=None and switch embeddings,
|
1611 |
+
# then re-embed, and then hit here and reload so have hash, and then re-embed.
|
1612 |
sources1 = path_to_docs(user_path, n_jobs=n_jobs, chunk=chunk, chunk_size=chunk_size,
|
1613 |
existing_files=existing_files, existing_hash_ids=existing_hash_ids)
|
1614 |
new_metadata_sources = set([x.metadata['source'] for x in sources1])
|
|
|
1652 |
new_sources_metadata = [x.metadata for x in sources]
|
1653 |
elif user_path is not None and langchain_mode in ['UserData']:
|
1654 |
print("Existing db, potentially adding %s sources from user_path=%s" % (len(sources), user_path), flush=True)
|
1655 |
+
db, num_new_sources, new_sources_metadata = add_to_db(db, sources, db_type=db_type,
|
1656 |
+
use_openai_embedding=use_openai_embedding,
|
1657 |
+
hf_embedding_model=hf_embedding_model)
|
1658 |
print("Existing db, added %s new sources from user_path=%s" % (num_new_sources, user_path), flush=True)
|
1659 |
else:
|
1660 |
new_sources_metadata = [x.metadata for x in sources]
|
|
|
1662 |
return db, len(new_sources_metadata), new_sources_metadata
|
1663 |
|
1664 |
|
1665 |
+
def get_metadatas(db):
|
1666 |
+
from langchain.vectorstores import FAISS
|
1667 |
+
if isinstance(db, FAISS):
|
1668 |
+
metadatas = [v.metadata for k, v in db.docstore._dict.items()]
|
1669 |
+
elif isinstance(db, Chroma):
|
1670 |
+
metadatas = get_documents(db)['metadatas']
|
1671 |
+
else:
|
1672 |
+
# FIXME: Hack due to https://github.com/weaviate/weaviate/issues/1947
|
1673 |
+
# seems no way to get all metadata, so need to avoid this approach for weaviate
|
1674 |
+
metadatas = [x.metadata for x in db.similarity_search("", k=10000)]
|
1675 |
+
return metadatas
|
1676 |
+
|
1677 |
+
|
1678 |
+
def get_documents(db):
|
1679 |
+
if hasattr(db, '_persist_directory'):
|
1680 |
+
name_path = os.path.basename(db._persist_directory)
|
1681 |
+
base_path = 'locks'
|
1682 |
+
makedirs(base_path)
|
1683 |
+
with filelock.FileLock(os.path.join(base_path, "getdb_%s.lock" % name_path)):
|
1684 |
+
# get segfaults and other errors when multiple threads access this
|
1685 |
+
return _get_documents(db)
|
1686 |
+
else:
|
1687 |
+
return _get_documents(db)
|
1688 |
+
|
1689 |
+
|
1690 |
+
def _get_documents(db):
|
1691 |
+
from langchain.vectorstores import FAISS
|
1692 |
+
if isinstance(db, FAISS):
|
1693 |
+
documents = [v for k, v in db.docstore._dict.items()]
|
1694 |
+
elif isinstance(db, Chroma):
|
1695 |
+
documents = db.get()
|
1696 |
+
else:
|
1697 |
+
# FIXME: Hack due to https://github.com/weaviate/weaviate/issues/1947
|
1698 |
+
# seems no way to get all metadata, so need to avoid this approach for weaviate
|
1699 |
+
documents = [x for x in db.similarity_search("", k=10000)]
|
1700 |
+
return documents
|
1701 |
+
|
1702 |
+
|
1703 |
+
def get_docs_and_meta(db, top_k_docs, filter_kwargs={}):
|
1704 |
+
if hasattr(db, '_persist_directory'):
|
1705 |
+
name_path = os.path.basename(db._persist_directory)
|
1706 |
+
base_path = 'locks'
|
1707 |
+
makedirs(base_path)
|
1708 |
+
with filelock.FileLock(os.path.join(base_path, "getdb_%s.lock" % name_path)):
|
1709 |
+
return _get_docs_and_meta(db, top_k_docs, filter_kwargs=filter_kwargs)
|
1710 |
+
else:
|
1711 |
+
return _get_docs_and_meta(db, top_k_docs, filter_kwargs=filter_kwargs)
|
1712 |
+
|
1713 |
+
|
1714 |
+
def _get_docs_and_meta(db, top_k_docs, filter_kwargs={}):
|
1715 |
+
from langchain.vectorstores import FAISS
|
1716 |
+
if isinstance(db, Chroma):
|
1717 |
+
db_get = db._collection.get(where=filter_kwargs.get('filter'))
|
1718 |
+
db_metadatas = db_get['metadatas']
|
1719 |
+
db_documents = db_get['documents']
|
1720 |
+
elif isinstance(db, FAISS):
|
1721 |
+
import itertools
|
1722 |
+
db_metadatas = get_metadatas(db)
|
1723 |
+
# FIXME: FAISS has no filter
|
1724 |
+
# slice dict first
|
1725 |
+
db_documents = list(dict(itertools.islice(db.docstore._dict.items(), top_k_docs)).values())
|
1726 |
+
else:
|
1727 |
+
db_metadatas = get_metadatas(db)
|
1728 |
+
db_documents = get_documents(db)
|
1729 |
+
return db_documents, db_metadatas
|
1730 |
+
|
1731 |
+
|
1732 |
def get_existing_files(db):
|
1733 |
+
metadatas = get_metadatas(db)
|
1734 |
+
metadata_sources = set([x['source'] for x in metadatas])
|
1735 |
return metadata_sources
|
1736 |
|
1737 |
|
1738 |
def get_existing_hash_ids(db):
|
1739 |
+
metadatas = get_metadatas(db)
|
1740 |
# assume consistency, that any prior hashed source was single hashed file at the time among all source chunks
|
1741 |
+
metadata_hash_ids = {x['source']: x.get('hashid') for x in metadatas}
|
1742 |
return metadata_hash_ids
|
1743 |
|
1744 |
|
|
|
|
1745 |
def run_qa_db(**kwargs):
|
1746 |
func_names = list(inspect.signature(_run_qa_db).parameters)
|
1747 |
# hard-coded defaults
|
1748 |
kwargs['answer_with_sources'] = True
|
|
|
1749 |
kwargs['show_rank'] = False
|
1750 |
missing_kwargs = [x for x in func_names if x not in kwargs]
|
1751 |
assert not missing_kwargs, "Missing kwargs: %s" % missing_kwargs
|
1752 |
# only keep actual used
|
1753 |
kwargs = {k: v for k, v in kwargs.items() if k in func_names}
|
1754 |
+
try:
|
1755 |
+
return _run_qa_db(**kwargs)
|
1756 |
+
finally:
|
1757 |
+
clear_torch_cache()
|
1758 |
|
1759 |
|
1760 |
def _run_qa_db(query=None,
|
1761 |
use_openai_model=False, use_openai_embedding=False,
|
1762 |
+
first_para=False, text_limit=None, top_k_docs=4, chunk=True, chunk_size=512,
|
1763 |
user_path=None,
|
1764 |
detect_user_path_changes_every_query=False,
|
1765 |
db_type='faiss',
|
1766 |
+
model_name=None, model=None, tokenizer=None, inference_server=None,
|
1767 |
hf_embedding_model="sentence-transformers/all-MiniLM-L6-v2",
|
1768 |
stream_output=False,
|
1769 |
prompter=None,
|
1770 |
prompt_type=None,
|
1771 |
+
prompt_dict=None,
|
1772 |
answer_with_sources=True,
|
1773 |
cut_distanct=1.1,
|
1774 |
+
sanitize_bot_response=False,
|
1775 |
show_rank=False,
|
1776 |
load_db_if_exists=False,
|
1777 |
db=None,
|
1778 |
+
do_sample=False,
|
1779 |
temperature=0.1,
|
|
|
1780 |
top_k=40,
|
1781 |
top_p=0.7,
|
1782 |
+
num_beams=1,
|
1783 |
+
max_new_tokens=256,
|
1784 |
+
min_new_tokens=1,
|
1785 |
+
early_stopping=False,
|
1786 |
+
max_time=180,
|
1787 |
+
repetition_penalty=1.0,
|
1788 |
+
num_return_sequences=1,
|
1789 |
langchain_mode=None,
|
1790 |
+
document_choice=[DocumentChoices.All_Relevant.name],
|
1791 |
n_jobs=-1,
|
1792 |
verbose=False,
|
1793 |
+
cli=False,
|
1794 |
+
reverse_docs=True,
|
1795 |
+
lora_weights='',
|
1796 |
+
auto_reduce_chunks=True,
|
1797 |
+
max_chunks=100,
|
1798 |
+
):
|
1799 |
"""
|
1800 |
|
1801 |
:param query:
|
|
|
1814 |
:param answer_with_sources
|
1815 |
:return:
|
1816 |
"""
|
1817 |
+
if model is not None:
|
1818 |
+
assert model_name is not None # require so can make decisions
|
1819 |
assert query is not None
|
1820 |
assert prompter is not None or prompt_type is not None or model is None # if model is None, then will generate
|
1821 |
if prompter is not None:
|
1822 |
prompt_type = prompter.prompt_type
|
1823 |
+
prompt_dict = prompter.prompt_dict
|
1824 |
if model is not None:
|
1825 |
assert prompt_type is not None
|
1826 |
+
if prompt_type == PromptType.custom.name:
|
1827 |
+
assert prompt_dict is not None # should at least be {} or ''
|
1828 |
+
else:
|
1829 |
+
prompt_dict = ''
|
1830 |
+
assert len(set(gen_hyper).difference(inspect.signature(get_llm).parameters)) == 0
|
1831 |
llm, model_name, streamer, prompt_type_out = get_llm(use_openai_model=use_openai_model, model_name=model_name,
|
1832 |
+
model=model,
|
1833 |
+
tokenizer=tokenizer,
|
1834 |
+
inference_server=inference_server,
|
1835 |
stream_output=stream_output,
|
1836 |
+
do_sample=do_sample,
|
1837 |
temperature=temperature,
|
|
|
1838 |
top_k=top_k,
|
1839 |
top_p=top_p,
|
1840 |
+
num_beams=num_beams,
|
1841 |
+
max_new_tokens=max_new_tokens,
|
1842 |
+
min_new_tokens=min_new_tokens,
|
1843 |
+
early_stopping=early_stopping,
|
1844 |
+
max_time=max_time,
|
1845 |
+
repetition_penalty=repetition_penalty,
|
1846 |
+
num_return_sequences=num_return_sequences,
|
1847 |
prompt_type=prompt_type,
|
1848 |
+
prompt_dict=prompt_dict,
|
1849 |
prompter=prompter,
|
1850 |
+
sanitize_bot_response=sanitize_bot_response,
|
1851 |
verbose=verbose,
|
1852 |
)
|
1853 |
|
|
|
|
|
|
|
|
|
1854 |
use_context = False
|
1855 |
scores = []
|
1856 |
chain = None
|
1857 |
|
1858 |
+
if isinstance(document_choice, str):
|
1859 |
+
# support string as well
|
1860 |
+
document_choice = [document_choice]
|
1861 |
+
# get first DocumentChoices as command to use, ignore others
|
1862 |
+
doc_choices_set = set([x.name for x in list(DocumentChoices)])
|
1863 |
+
cmd = [x for x in document_choice if x in doc_choices_set]
|
1864 |
+
cmd = None if len(cmd) == 0 else cmd[0]
|
1865 |
+
# now have cmd, filter out for only docs
|
1866 |
+
document_choice = [x for x in document_choice if x not in doc_choices_set]
|
1867 |
+
|
1868 |
func_names = list(inspect.signature(get_similarity_chain).parameters)
|
1869 |
sim_kwargs = {k: v for k, v in locals().items() if k in func_names}
|
1870 |
missing_kwargs = [x for x in func_names if x not in sim_kwargs]
|
1871 |
assert not missing_kwargs, "Missing: %s" % missing_kwargs
|
1872 |
docs, chain, scores, use_context = get_similarity_chain(**sim_kwargs)
|
1873 |
+
if cmd in [DocumentChoices.All_Relevant_Only_Sources.name, DocumentChoices.Only_All_Sources.name]:
|
1874 |
formatted_doc_chunks = '\n\n'.join([get_url(x) + '\n\n' + x.page_content for x in docs])
|
1875 |
yield formatted_doc_chunks, ''
|
1876 |
return
|
|
|
1878 |
# can only return if HF type
|
1879 |
return
|
1880 |
|
1881 |
+
# context stuff similar to used in evaluate()
|
1882 |
+
import torch
|
1883 |
+
device, torch_dtype, context_class = get_device_dtype()
|
1884 |
+
with torch.no_grad():
|
1885 |
+
have_lora_weights = lora_weights not in [no_lora_str, '', None]
|
1886 |
+
context_class_cast = NullContext if device == 'cpu' or have_lora_weights else torch.autocast
|
1887 |
+
with context_class_cast(device):
|
1888 |
+
if stream_output and streamer:
|
1889 |
+
answer = None
|
1890 |
+
import queue
|
1891 |
+
bucket = queue.Queue()
|
1892 |
+
thread = EThread(target=chain, streamer=streamer, bucket=bucket)
|
1893 |
+
thread.start()
|
1894 |
+
outputs = ""
|
1895 |
+
prompt = None # FIXME
|
1896 |
+
try:
|
1897 |
+
for new_text in streamer:
|
1898 |
+
# print("new_text: %s" % new_text, flush=True)
|
1899 |
+
if bucket.qsize() > 0 or thread.exc:
|
1900 |
+
thread.join()
|
1901 |
+
outputs += new_text
|
1902 |
+
if prompter: # and False: # FIXME: pipeline can already use prompter
|
1903 |
+
output1 = prompter.get_response(outputs, prompt=prompt,
|
1904 |
+
sanitize_bot_response=sanitize_bot_response)
|
1905 |
+
yield output1, ''
|
1906 |
+
else:
|
1907 |
+
yield outputs, ''
|
1908 |
+
except BaseException:
|
1909 |
+
# if any exception, raise that exception if was from thread, first
|
1910 |
+
if thread.exc:
|
1911 |
+
raise thread.exc
|
1912 |
+
raise
|
1913 |
+
finally:
|
1914 |
+
# in case no exception and didn't join with thread yet, then join
|
1915 |
+
if not thread.exc:
|
1916 |
+
answer = thread.join()
|
1917 |
+
# in case raise StopIteration or broke queue loop in streamer, but still have exception
|
1918 |
+
if thread.exc:
|
1919 |
+
raise thread.exc
|
1920 |
+
# FIXME: answer is not string outputs from streamer. How to get actual final output?
|
1921 |
+
# answer = outputs
|
1922 |
+
else:
|
1923 |
+
answer = chain()
|
1924 |
|
1925 |
if not use_context:
|
1926 |
ret = answer['output_text']
|
|
|
1934 |
|
1935 |
def get_similarity_chain(query=None,
|
1936 |
use_openai_model=False, use_openai_embedding=False,
|
1937 |
+
first_para=False, text_limit=None, top_k_docs=4, chunk=True, chunk_size=512,
|
1938 |
user_path=None,
|
1939 |
detect_user_path_changes_every_query=False,
|
1940 |
db_type='faiss',
|
1941 |
model_name=None,
|
1942 |
+
inference_server='',
|
1943 |
hf_embedding_model="sentence-transformers/all-MiniLM-L6-v2",
|
1944 |
prompt_type=None,
|
1945 |
+
prompt_dict=None,
|
1946 |
cut_distanct=1.1,
|
1947 |
load_db_if_exists=False,
|
1948 |
db=None,
|
1949 |
langchain_mode=None,
|
1950 |
+
document_choice=[DocumentChoices.All_Relevant.name],
|
1951 |
n_jobs=-1,
|
1952 |
# beyond run_db_query:
|
1953 |
llm=None,
|
1954 |
+
tokenizer=None,
|
1955 |
verbose=False,
|
1956 |
+
cmd=None,
|
1957 |
+
reverse_docs=True,
|
1958 |
+
|
1959 |
+
# local
|
1960 |
+
auto_reduce_chunks=True,
|
1961 |
+
max_chunks=100,
|
1962 |
):
|
1963 |
# determine whether use of context out of docs is planned
|
1964 |
if not use_openai_model and prompt_type not in ['plain'] or model_name in non_hf_types:
|
|
|
1970 |
use_context = True
|
1971 |
|
1972 |
# https://github.com/hwchase17/langchain/issues/1946
|
1973 |
+
# FIXME: Seems no way to get size of chroma db to limit top_k_docs to avoid
|
1974 |
# Chroma collection MyData contains fewer than 4 elements.
|
1975 |
# type logger error
|
1976 |
+
if top_k_docs == -1:
|
1977 |
+
k_db = 1000 if db_type == 'chroma' else 100
|
1978 |
+
else:
|
1979 |
+
# top_k_docs=100 works ok too
|
1980 |
+
k_db = 1000 if db_type == 'chroma' else top_k_docs
|
1981 |
|
1982 |
# FIXME: For All just go over all dbs instead of a separate db for All
|
1983 |
if not detect_user_path_changes_every_query and db is not None:
|
|
|
1987 |
user_path = None
|
1988 |
db, num_new_sources, new_sources_metadata = make_db(use_openai_embedding=use_openai_embedding,
|
1989 |
hf_embedding_model=hf_embedding_model,
|
1990 |
+
first_para=first_para, text_limit=text_limit,
|
1991 |
+
chunk=chunk,
|
1992 |
chunk_size=chunk_size,
|
1993 |
langchain_mode=langchain_mode,
|
1994 |
user_path=user_path,
|
|
|
1998 |
n_jobs=n_jobs,
|
1999 |
verbose=verbose)
|
2000 |
|
2001 |
+
if 'falcon' in model_name:
|
2002 |
+
extra = "According to only the information in the document sources provided within the context above, "
|
2003 |
+
prefix = "Pay attention and remember information below, which will help to answer the question or imperative after the context ends."
|
2004 |
+
elif inference_server in ['openai', 'openai_chat']:
|
2005 |
+
extra = "According to (primarily) the information in the document sources provided within context above, "
|
2006 |
+
prefix = "Pay attention and remember information below, which will help to answer the question or imperative after the context ends. If the answer cannot be primarily obtained from information within the context, then respond that the answer does not appear in the context of the documents."
|
2007 |
+
else:
|
2008 |
+
extra = ""
|
2009 |
+
prefix = ""
|
2010 |
+
if langchain_mode in ['Disabled', 'ChatLLM', 'LLM'] or not use_context:
|
2011 |
+
template_if_no_docs = template = """%s{context}{question}""" % prefix
|
2012 |
+
else:
|
2013 |
+
template = """%s
|
2014 |
+
\"\"\"
|
2015 |
+
{context}
|
2016 |
+
\"\"\"
|
2017 |
+
%s{question}""" % (prefix, extra)
|
2018 |
+
template_if_no_docs = """%s{context}%s{question}""" % (prefix, extra)
|
2019 |
+
if not use_openai_model and prompt_type not in ['plain'] or model_name in non_hf_types:
|
2020 |
+
use_template = True
|
2021 |
+
else:
|
2022 |
+
use_template = False
|
2023 |
+
|
2024 |
if db and use_context:
|
2025 |
+
if not isinstance(db, Chroma):
|
2026 |
+
# only chroma supports filtering
|
|
|
|
2027 |
filter_kwargs = {}
|
2028 |
else:
|
2029 |
+
# if here then some cmd + documents selected or just documents selected
|
2030 |
if len(document_choice) >= 2:
|
2031 |
or_filter = [{"source": {"$eq": x}} for x in document_choice]
|
2032 |
filter_kwargs = dict(filter={"$or": or_filter})
|
2033 |
+
elif len(document_choice) == 1:
|
2034 |
+
# degenerate UX bug in chroma
|
2035 |
one_filter = [{"source": {"$eq": x}} for x in document_choice][0]
|
2036 |
filter_kwargs = dict(filter=one_filter)
|
2037 |
else:
|
2038 |
+
# shouldn't reach
|
2039 |
filter_kwargs = {}
|
2040 |
+
if cmd == DocumentChoices.Just_LLM.name:
|
2041 |
+
docs = []
|
2042 |
+
scores = []
|
2043 |
+
elif cmd == DocumentChoices.Only_All_Sources.name:
|
2044 |
+
db_documents, db_metadatas = get_docs_and_meta(db, top_k_docs, filter_kwargs=filter_kwargs)
|
2045 |
+
# similar to langchain's chroma's _results_to_docs_and_scores
|
2046 |
+
docs_with_score = [(Document(page_content=result[0], metadata=result[1] or {}), 0)
|
2047 |
+
for result in zip(db_documents, db_metadatas)][:top_k_docs]
|
2048 |
+
docs = [x[0] for x in docs_with_score]
|
2049 |
+
scores = [x[1] for x in docs_with_score]
|
2050 |
+
else:
|
2051 |
+
if top_k_docs == -1 or auto_reduce_chunks:
|
2052 |
+
# docs_with_score = db.similarity_search_with_score(query, k=k_db, **filter_kwargs)[:top_k_docs]
|
2053 |
+
top_k_docs_tokenize = 100
|
2054 |
+
base_path = 'locks'
|
2055 |
+
makedirs(base_path)
|
2056 |
+
if hasattr(db, '_persist_directory'):
|
2057 |
+
name_path = "sim_%s.lock" % os.path.basename(db._persist_directory)
|
2058 |
+
else:
|
2059 |
+
name_path = "sim.lock"
|
2060 |
+
with filelock.FileLock(os.path.join(base_path, name_path)):
|
2061 |
+
docs_with_score = db.similarity_search_with_score(query, k=k_db, **filter_kwargs)[
|
2062 |
+
:top_k_docs_tokenize]
|
2063 |
+
if hasattr(llm, 'pipeline') and hasattr(llm.pipeline, 'tokenizer'):
|
2064 |
+
# more accurate
|
2065 |
+
tokens = [len(llm.pipeline.tokenizer(x[0].page_content)['input_ids']) for x in docs_with_score]
|
2066 |
+
template_tokens = len(llm.pipeline.tokenizer(template)['input_ids'])
|
2067 |
+
elif inference_server in ['openai', 'openai_chat'] or use_openai_model or db_type in ['faiss',
|
2068 |
+
'weaviate']:
|
2069 |
+
# use tiktoken for faiss since embedding called differently
|
2070 |
+
tokens = [llm.get_num_tokens(x[0].page_content) for x in docs_with_score]
|
2071 |
+
template_tokens = llm.get_num_tokens(template)
|
2072 |
+
elif isinstance(tokenizer, FakeTokenizer):
|
2073 |
+
tokens = [tokenizer.num_tokens_from_string(x[0].page_content) for x in docs_with_score]
|
2074 |
+
template_tokens = tokenizer.num_tokens_from_string(template)
|
2075 |
+
else:
|
2076 |
+
# in case model is not our pipeline with HF tokenizer
|
2077 |
+
tokens = [db._embedding_function.client.tokenize([x[0].page_content])['input_ids'].shape[1] for x in
|
2078 |
+
docs_with_score]
|
2079 |
+
template_tokens = db._embedding_function.client.tokenize([template])['input_ids'].shape[1]
|
2080 |
+
tokens_cumsum = np.cumsum(tokens)
|
2081 |
+
if hasattr(llm, 'pipeline') and hasattr(llm.pipeline, 'max_input_tokens'):
|
2082 |
+
max_input_tokens = llm.pipeline.max_input_tokens
|
2083 |
+
elif inference_server in ['openai']:
|
2084 |
+
max_tokens = llm.modelname_to_contextsize(model_name)
|
2085 |
+
# leave some room for 1 paragraph, even if min_new_tokens=0
|
2086 |
+
max_input_tokens = max_tokens - 256
|
2087 |
+
elif inference_server in ['openai_chat']:
|
2088 |
+
max_tokens = model_token_mapping[model_name]
|
2089 |
+
# leave some room for 1 paragraph, even if min_new_tokens=0
|
2090 |
+
max_input_tokens = max_tokens - 256
|
2091 |
+
elif isinstance(tokenizer, FakeTokenizer):
|
2092 |
+
max_input_tokens = tokenizer.model_max_length - 256
|
2093 |
+
else:
|
2094 |
+
# leave some room for 1 paragraph, even if min_new_tokens=0
|
2095 |
+
max_input_tokens = 2048 - 256
|
2096 |
+
max_input_tokens -= template_tokens
|
2097 |
+
# FIXME: Doesn't account for query, == context, or new lines between contexts
|
2098 |
+
where_res = np.where(tokens_cumsum < max_input_tokens)[0]
|
2099 |
+
if where_res.shape[0] == 0:
|
2100 |
+
# then no chunk can fit, still do first one
|
2101 |
+
top_k_docs_trial = 1
|
2102 |
+
else:
|
2103 |
+
top_k_docs_trial = 1 + where_res[-1]
|
2104 |
+
if 0 < top_k_docs_trial < max_chunks:
|
2105 |
+
# avoid craziness
|
2106 |
+
if top_k_docs == -1:
|
2107 |
+
top_k_docs = top_k_docs_trial
|
2108 |
+
else:
|
2109 |
+
top_k_docs = min(top_k_docs, top_k_docs_trial)
|
2110 |
+
if top_k_docs == -1:
|
2111 |
+
# if here, means 0 and just do best with 1 doc
|
2112 |
+
print("Unexpected large chunks and can't add to context, will add 1 anyways", flush=True)
|
2113 |
+
top_k_docs = 1
|
2114 |
+
docs_with_score = docs_with_score[:top_k_docs]
|
2115 |
+
else:
|
2116 |
+
docs_with_score = db.similarity_search_with_score(query, k=k_db, **filter_kwargs)[:top_k_docs]
|
2117 |
+
# put most relevant chunks closest to question,
|
2118 |
+
# esp. if truncation occurs will be "oldest" or "farthest from response" text that is truncated
|
2119 |
+
# BUT: for small models, e.g. 6_9 pythia, if sees some stuff related to h2oGPT first, it can connect that and not listen to rest
|
2120 |
+
if reverse_docs:
|
2121 |
+
docs_with_score.reverse()
|
2122 |
+
# cut off so no high distance docs/sources considered
|
2123 |
+
docs = [x[0] for x in docs_with_score if x[1] < cut_distanct]
|
2124 |
+
scores = [x[1] for x in docs_with_score if x[1] < cut_distanct]
|
2125 |
+
if len(scores) > 0 and verbose:
|
2126 |
+
print("Distance: min: %s max: %s mean: %s median: %s" %
|
2127 |
+
(scores[0], scores[-1], np.mean(scores), np.median(scores)), flush=True)
|
2128 |
else:
|
2129 |
docs = []
|
2130 |
scores = []
|
|
|
2133 |
# if HF type and have no docs, can bail out
|
2134 |
return docs, None, [], False
|
2135 |
|
2136 |
+
if cmd in [DocumentChoices.All_Relevant_Only_Sources.name, DocumentChoices.Only_All_Sources.name]:
|
2137 |
# no LLM use
|
2138 |
return docs, None, [], False
|
2139 |
|
|
|
2153 |
if len(docs) == 0:
|
2154 |
# avoid context == in prompt then
|
2155 |
use_context = False
|
2156 |
+
template = template_if_no_docs
|
2157 |
|
2158 |
+
if use_template:
|
2159 |
# instruct-like, rather than few-shot prompt_type='plain' as default
|
2160 |
# but then sources confuse the model with how inserted among rest of text, so avoid
|
|
|
|
2161 |
prompt = PromptTemplate(
|
2162 |
# input_variables=["summaries", "question"],
|
2163 |
input_variables=["context", "question"],
|
|
|
2217 |
return ret, extra
|
2218 |
|
2219 |
|
2220 |
+
def clean_doc(docs1):
|
2221 |
+
if not isinstance(docs1, (list, tuple, types.GeneratorType)):
|
2222 |
+
docs1 = [docs1]
|
2223 |
+
for doci, doc in enumerate(docs1):
|
2224 |
+
docs1[doci].page_content = '\n'.join([x.strip() for x in doc.page_content.split("\n") if x.strip()])
|
2225 |
+
return docs1
|
2226 |
+
|
2227 |
+
|
2228 |
+
def chunk_sources(sources, chunk=True, chunk_size=512, language=None):
|
2229 |
+
if not chunk:
|
2230 |
+
return sources
|
2231 |
+
if not isinstance(sources, (list, tuple, types.GeneratorType)) and not callable(sources):
|
2232 |
+
# if just one document
|
2233 |
+
sources = [sources]
|
2234 |
+
if language and False:
|
2235 |
+
# Bug in langchain, keep separator=True not working
|
2236 |
+
# https://github.com/hwchase17/langchain/issues/2836
|
2237 |
+
# so avoid this for now
|
2238 |
+
keep_separator = True
|
2239 |
+
separators = RecursiveCharacterTextSplitter.get_separators_for_language(language)
|
2240 |
+
else:
|
2241 |
+
separators = ["\n\n", "\n", " ", ""]
|
2242 |
+
keep_separator = False
|
2243 |
+
splitter = RecursiveCharacterTextSplitter(chunk_size=chunk_size, chunk_overlap=0, keep_separator=keep_separator,
|
2244 |
+
separators=separators)
|
2245 |
+
source_chunks = splitter.split_documents(sources)
|
2246 |
return source_chunks
|
2247 |
|
2248 |
|
|
|
2253 |
path_to_zip_file = hf_hub_download('h2oai/db_dirs', db_dir, token=token, repo_type='dataset')
|
2254 |
import zipfile
|
2255 |
with zipfile.ZipFile(path_to_zip_file, 'r') as zip_ref:
|
2256 |
+
persist_directory = os.path.dirname(zip_ref.namelist()[0])
|
2257 |
+
remove(persist_directory)
|
2258 |
zip_ref.extractall(dest)
|
2259 |
return path_to_zip_file
|
2260 |
|
|
|
2283 |
assert os.path.isdir(os.path.join(dest, dir_expected, 'index')), "Missing index in %s" % dir_expected
|
2284 |
|
2285 |
|
2286 |
+
def _create_local_weaviate_client():
|
2287 |
+
WEAVIATE_URL = os.getenv('WEAVIATE_URL', "http://localhost:8080")
|
2288 |
+
WEAVIATE_USERNAME = os.getenv('WEAVIATE_USERNAME')
|
2289 |
+
WEAVIATE_PASSWORD = os.getenv('WEAVIATE_PASSWORD')
|
2290 |
+
WEAVIATE_SCOPE = os.getenv('WEAVIATE_SCOPE', "offline_access")
|
2291 |
+
|
2292 |
+
resource_owner_config = None
|
2293 |
+
try:
|
2294 |
+
import weaviate
|
2295 |
+
if WEAVIATE_USERNAME is not None and WEAVIATE_PASSWORD is not None:
|
2296 |
+
resource_owner_config = weaviate.AuthClientPassword(
|
2297 |
+
username=WEAVIATE_USERNAME,
|
2298 |
+
password=WEAVIATE_PASSWORD,
|
2299 |
+
scope=WEAVIATE_SCOPE
|
2300 |
+
)
|
2301 |
+
|
2302 |
+
client = weaviate.Client(WEAVIATE_URL, auth_client_secret=resource_owner_config)
|
2303 |
+
return client
|
2304 |
+
except Exception as e:
|
2305 |
+
print(f"Failed to create Weaviate client: {e}")
|
2306 |
+
return None
|
2307 |
+
|
2308 |
+
|
2309 |
if __name__ == '__main__':
|
2310 |
pass
|
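The new chunk_sources helper wraps langchain's RecursiveCharacterTextSplitter, so it can be exercised on its own. A minimal sketch, assuming gpt_langchain is importable from the repo root; the document text and source name below are invented:

from langchain.docstore.document import Document
from gpt_langchain import chunk_sources

# invented single-document input; metadata['source'] mirrors what path_to_docs sets
docs = [Document(page_content="some long text " * 400, metadata=dict(source='example.txt'))]
chunks = chunk_sources(docs, chunk=True, chunk_size=512)
print(len(chunks), chunks[0].metadata)

Each chunk keeps its source metadata, which get_existing_files and get_existing_hash_ids rely on when deciding what to re-ingest.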
gradio_runner.py
CHANGED
The diff for this file is too large to render.
|
|
gradio_themes.py
CHANGED
@@ -3,7 +3,7 @@ from __future__ import annotations
|
|
3 |
from typing import Iterable
|
4 |
|
5 |
from gradio.themes.soft import Soft
|
6 |
-
from gradio.themes import Color
|
7 |
from gradio.themes.utils import colors, sizes, fonts
|
8 |
|
9 |
h2o_yellow = Color(
|
@@ -36,6 +36,42 @@ h2o_gray = Color(
|
|
36 |
)
|
37 |
|
38 |
|
|
|
|
|
|
|
|
|
|
|
|
|
39 |
class H2oTheme(Soft):
|
40 |
def __init__(
|
41 |
self,
|
@@ -158,19 +194,23 @@ h2o_logo = '<svg id="Layer_1" data-name="Layer 1" xmlns="http://www.w3.org/2000/
|
|
158 |
'11.93S497.56,252.55,497.56,246.06Zm2.52,21.47h20.68v71.31H500.08Z"/></svg>'
|
159 |
|
160 |
|
161 |
-
def get_h2o_title(title):
|
162 |
-
|
|
|
|
|
|
|
|
|
163 |
<div style="height: 60px; width: 60px; margin-right:20px;">{h2o_logo}</div>
|
164 |
<h1 style="line-height:60px">{title}</h1>
|
165 |
</div>
|
166 |
<div style="float:right; height: 80px; width: 80px; margin-top:-100px">
|
167 |
-
<img src=https://raw.githubusercontent.com/h2oai/h2ogpt/main/docs/h2o-qr.png
|
168 |
</div>
|
169 |
"""
|
170 |
|
171 |
|
172 |
-
def get_simple_title(title):
|
173 |
-
return f"""<h1 align="center"> {title}</h1>"""
|
174 |
|
175 |
|
176 |
def get_dark_js():
|
|
|
3 |
from typing import Iterable
|
4 |
|
5 |
from gradio.themes.soft import Soft
|
6 |
+
from gradio.themes import Color, Size
|
7 |
from gradio.themes.utils import colors, sizes, fonts
|
8 |
|
9 |
h2o_yellow = Color(
|
|
|
36 |
)
|
37 |
|
38 |
|
39 |
+
text_xsm = Size(
|
40 |
+
name="text_xsm",
|
41 |
+
xxs="4px",
|
42 |
+
xs="5px",
|
43 |
+
sm="6px",
|
44 |
+
md="7px",
|
45 |
+
lg="8px",
|
46 |
+
xl="10px",
|
47 |
+
xxl="12px",
|
48 |
+
)
|
49 |
+
|
50 |
+
|
51 |
+
spacing_xsm = Size(
|
52 |
+
name="spacing_xsm",
|
53 |
+
xxs="1px",
|
54 |
+
xs="1px",
|
55 |
+
sm="1px",
|
56 |
+
md="2px",
|
57 |
+
lg="3px",
|
58 |
+
xl="5px",
|
59 |
+
xxl="7px",
|
60 |
+
)
|
61 |
+
|
62 |
+
|
63 |
+
radius_xsm = Size(
|
64 |
+
name="radius_xsm",
|
65 |
+
xxs="1px",
|
66 |
+
xs="1px",
|
67 |
+
sm="1px",
|
68 |
+
md="2px",
|
69 |
+
lg="3px",
|
70 |
+
xl="5px",
|
71 |
+
xxl="7px",
|
72 |
+
)
|
73 |
+
|
74 |
+
|
75 |
class H2oTheme(Soft):
|
76 |
def __init__(
|
77 |
self,
|
|
|
194 |
'11.93S497.56,252.55,497.56,246.06Zm2.52,21.47h20.68v71.31H500.08Z"/></svg>'
|
195 |
|
196 |
|
197 |
+
def get_h2o_title(title, description):
|
198 |
+
# NOTE: Check full width desktop, smallest width browser desktop, iPhone browsers to ensure no overlap etc.
|
199 |
+
return f"""<div style="float:left; justify-content:left; height: 80px; width: 195px; margin-top:0px">
|
200 |
+
{description}
|
201 |
+
</div>
|
202 |
+
<div style="display:flex; justify-content:center; margin-bottom:30px; margin-right:330px;">
|
203 |
<div style="height: 60px; width: 60px; margin-right:20px;">{h2o_logo}</div>
|
204 |
<h1 style="line-height:60px">{title}</h1>
|
205 |
</div>
|
206 |
<div style="float:right; height: 80px; width: 80px; margin-top:-100px">
|
207 |
+
<img src="https://raw.githubusercontent.com/h2oai/h2ogpt/main/docs/h2o-qr.png">
|
208 |
</div>
|
209 |
"""
|
210 |
|
211 |
|
212 |
+
def get_simple_title(title, description):
|
213 |
+
return f"""{description}<h1 align="center"> {title}</h1>"""
|
214 |
|
215 |
|
216 |
def get_dark_js():
|
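A short sketch of using the updated theme and the new two-argument get_h2o_title; the Blocks app, title, and description below are invented, not part of this commit, and default H2oTheme() arguments are assumed to work as with Soft:

import gradio as gr
from gradio_themes import H2oTheme, get_h2o_title

with gr.Blocks(theme=H2oTheme()) as demo:
    gr.HTML(get_h2o_title("h2oGPT", "invented description"))
demo.launch()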
gradio_utils/__pycache__/css.cpython-310.pyc
ADDED
Binary file (1.53 kB).
|
|
gradio_utils/__pycache__/grclient.cpython-310.pyc
ADDED
Binary file (2.69 kB).
|
|
gradio_utils/__pycache__/prompt_form.cpython-310.pyc
ADDED
Binary file (3.59 kB).
|
|
gradio_utils/css.py
ADDED
@@ -0,0 +1,53 @@
|
|
|
|
|
|
1 |
+
def get_css(kwargs) -> str:
|
2 |
+
if kwargs['h2ocolors']:
|
3 |
+
css_code = """footer {visibility: hidden;}
|
4 |
+
body{background:linear-gradient(#f5f5f5,#e5e5e5);}
|
5 |
+
body.dark{background:linear-gradient(#000000,#0d0d0d);}
|
6 |
+
"""
|
7 |
+
else:
|
8 |
+
css_code = """footer {visibility: hidden}"""
|
9 |
+
|
10 |
+
css_code += make_css_base()
|
11 |
+
return css_code
|
12 |
+
|
13 |
+
|
14 |
+
def make_css_base() -> str:
|
15 |
+
return """
|
16 |
+
@import url('https://fonts.googleapis.com/css2?family=Source+Sans+Pro:wght@400;600&display=swap');
|
17 |
+
|
18 |
+
body.dark{#warning {background-color: #555555};}
|
19 |
+
|
20 |
+
#small_btn {
|
21 |
+
margin: 0.6em 0em 0.55em 0;
|
22 |
+
max-width: 20em;
|
23 |
+
min-width: 5em !important;
|
24 |
+
height: 5em;
|
25 |
+
font-size: 14px !important;
|
26 |
+
}
|
27 |
+
|
28 |
+
#prompt-form {
|
29 |
+
border: 1px solid var(--primary-500) !important;
|
30 |
+
}
|
31 |
+
|
32 |
+
#prompt-form.block {
|
33 |
+
border-radius: var(--block-radius) !important;
|
34 |
+
}
|
35 |
+
|
36 |
+
#prompt-form textarea {
|
37 |
+
border: 1px solid rgb(209, 213, 219);
|
38 |
+
}
|
39 |
+
|
40 |
+
#prompt-form label > div {
|
41 |
+
margin-top: 4px;
|
42 |
+
}
|
43 |
+
|
44 |
+
button.primary:hover {
|
45 |
+
background-color: var(--primary-600) !important;
|
46 |
+
transition: .2s;
|
47 |
+
}
|
48 |
+
|
49 |
+
#prompt-form-area {
|
50 |
+
margin-bottom: 2.5rem;
|
51 |
+
}
|
52 |
+
.chatsmall chatbot {font-size: 10px !important}
|
53 |
+
"""
|
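A short sketch of wiring the new get_css helper into a Blocks app; only 'h2ocolors' is read from the dict, and everything else below is invented:

import gradio as gr
from gradio_utils.css import get_css

css_code = get_css(dict(h2ocolors=True))
with gr.Blocks(css=css_code) as demo:
    gr.Markdown("h2oGPT")
demo.launch()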
gradio_utils/grclient.py
ADDED
@@ -0,0 +1,82 @@
|
|
|
|
|
|
|
1 |
+
import traceback
|
2 |
+
from typing import Callable
|
3 |
+
import os
|
4 |
+
|
5 |
+
from gradio_client.client import Job
|
6 |
+
|
7 |
+
os.environ['HF_HUB_DISABLE_TELEMETRY'] = '1'
|
8 |
+
|
9 |
+
from gradio_client import Client
|
10 |
+
|
11 |
+
|
12 |
+
class GradioClient(Client):
|
13 |
+
"""
|
14 |
+
Parent class of gradio client
|
15 |
+
To handle automatically refreshing client if detect gradio server changed
|
16 |
+
"""
|
17 |
+
|
18 |
+
def __init__(self, *args, **kwargs):
|
19 |
+
self.args = args
|
20 |
+
self.kwargs = kwargs
|
21 |
+
super().__init__(*args, **kwargs)
|
22 |
+
self.server_hash = self.get_server_hash()
|
23 |
+
|
24 |
+
def get_server_hash(self):
|
25 |
+
"""
|
26 |
+
Get server hash using super without any refresh action triggered
|
27 |
+
Returns: git hash of gradio server
|
28 |
+
"""
|
29 |
+
return super().submit(api_name='/system_hash').result()
|
30 |
+
|
31 |
+
def refresh_client_if_should(self):
|
32 |
+
# get current hash in order to update api_name -> fn_index map in case gradio server changed
|
33 |
+
# FIXME: Could add cli api as hash
|
34 |
+
server_hash = self.get_server_hash()
|
35 |
+
if self.server_hash != server_hash:
|
36 |
+
self.refresh_client()
|
37 |
+
self.server_hash = server_hash
|
38 |
+
else:
|
39 |
+
self.reset_session()
|
40 |
+
|
41 |
+
def refresh_client(self):
|
42 |
+
"""
|
43 |
+
Ensure every client call is independent
|
44 |
+
Also ensure map between api_name and fn_index is updated in case server changed (e.g. restarted with new code)
|
45 |
+
Returns:
|
46 |
+
"""
|
47 |
+
# need session hash to be new every time, to avoid "generator already executing"
|
48 |
+
self.reset_session()
|
49 |
+
|
50 |
+
client = Client(*self.args, **self.kwargs)
|
51 |
+
for k, v in client.__dict__.items():
|
52 |
+
setattr(self, k, v)
|
53 |
+
|
54 |
+
def submit(
|
55 |
+
self,
|
56 |
+
*args,
|
57 |
+
api_name: str | None = None,
|
58 |
+
fn_index: int | None = None,
|
59 |
+
result_callbacks: Callable | list[Callable] | None = None,
|
60 |
+
) -> Job:
|
61 |
+
# Note predict calls submit
|
62 |
+
try:
|
63 |
+
self.refresh_client_if_should()
|
64 |
+
job = super().submit(*args, api_name=api_name, fn_index=fn_index)
|
65 |
+
except Exception as e:
|
66 |
+
print("Hit e=%s" % str(e), flush=True)
|
67 |
+
# force reconfig in case only that
|
68 |
+
self.refresh_client()
|
69 |
+
job = super().submit(*args, api_name=api_name, fn_index=fn_index)
|
70 |
+
|
71 |
+
# see if immediately failed
|
72 |
+
e = job.future._exception
|
73 |
+
if e is not None:
|
74 |
+
print("GR job failed: %s %s" % (str(e), ''.join(traceback.format_tb(e.__traceback__))), flush=True)
|
75 |
+
# force reconfig in case only that
|
76 |
+
self.refresh_client()
|
77 |
+
job = super().submit(*args, api_name=api_name, fn_index=fn_index)
|
78 |
+
e2 = job.future._exception
|
79 |
+
if e2 is not None:
|
80 |
+
print("GR job failed again: %s\n%s" % (str(e2), ''.join(traceback.format_tb(e2.__traceback__))), flush=True)
|
81 |
+
|
82 |
+
return job
|
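A short usage sketch for GradioClient; the server URL and api_name are assumptions, not taken from this commit. Since predict() goes through submit(), both transparently refresh the client whenever the server hash changes:

from gradio_utils.grclient import GradioClient

client = GradioClient("http://localhost:7860")  # assumed local h2oGPT server
job = client.submit("Who are you?", api_name='/submit_nochat')  # assumed endpoint name
print(job.result())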
gradio_utils/prompt_form.py
ADDED
@@ -0,0 +1,118 @@
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import math
|
3 |
+
|
4 |
+
import gradio as gr
|
5 |
+
|
6 |
+
|
7 |
+
def make_chatbots(output_label0, output_label0_model2, **kwargs):
|
8 |
+
text_outputs = []
|
9 |
+
chat_kwargs = []
|
10 |
+
for model_state_lock in kwargs['model_states']:
|
11 |
+
if os.environ.get('DEBUG_MODEL_LOCK'):
|
12 |
+
model_name = model_state_lock["base_model"] + " : " + model_state_lock["inference_server"]
|
13 |
+
else:
|
14 |
+
model_name = model_state_lock["base_model"]
|
15 |
+
output_label = f'h2oGPT [{model_name}]'
|
16 |
+
min_width = 250 if kwargs['gradio_size'] in ['small', 'large', 'medium'] else 160
|
17 |
+
chat_kwargs.append(dict(label=output_label, visible=kwargs['model_lock'], elem_classes='chatsmall',
|
18 |
+
height=kwargs['height'] or 400, min_width=min_width))
|
19 |
+
|
20 |
+
if kwargs['model_lock_columns'] == -1:
|
21 |
+
kwargs['model_lock_columns'] = len(kwargs['model_states'])
|
22 |
+
if kwargs['model_lock_columns'] is None:
|
23 |
+
kwargs['model_lock_columns'] = 3
|
24 |
+
|
25 |
+
ncols = kwargs['model_lock_columns']
|
26 |
+
if kwargs['model_states'] == 0:
|
27 |
+
nrows = 0
|
28 |
+
else:
|
29 |
+
nrows = math.ceil(len(kwargs['model_states']) / kwargs['model_lock_columns'])
|
30 |
+
|
31 |
+
if kwargs['model_lock_columns'] == 0:
|
32 |
+
# not using model_lock
|
33 |
+
pass
|
34 |
+
elif nrows <= 1:
|
35 |
+
with gr.Row():
|
36 |
+
for chat_kwargs1, model_state_lock in zip(chat_kwargs, kwargs['model_states']):
|
37 |
+
text_outputs.append(gr.Chatbot(**chat_kwargs1))
|
38 |
+
elif nrows == kwargs['model_states']:
|
39 |
+
with gr.Row():
|
40 |
+
for chat_kwargs1, model_state_lock in zip(chat_kwargs, kwargs['model_states']):
|
41 |
+
text_outputs.append(gr.Chatbot(**chat_kwargs1))
|
42 |
+
elif nrows == 2:
|
43 |
+
with gr.Row():
|
44 |
+
for mii, (chat_kwargs1, model_state_lock) in enumerate(zip(chat_kwargs, kwargs['model_states'])):
|
45 |
+
if mii >= len(kwargs['model_states']) / 2:
|
46 |
+
continue
|
47 |
+
text_outputs.append(gr.Chatbot(**chat_kwargs1))
|
48 |
+
with gr.Row():
|
49 |
+
for mii, (chat_kwargs1, model_state_lock) in enumerate(zip(chat_kwargs, kwargs['model_states'])):
|
50 |
+
if mii < len(kwargs['model_states']) / 2:
|
51 |
+
continue
|
52 |
+
text_outputs.append(gr.Chatbot(**chat_kwargs1))
|
53 |
+
elif nrows == 3:
|
54 |
+
with gr.Row():
|
55 |
+
for mii, (chat_kwargs1, model_state_lock) in enumerate(zip(chat_kwargs, kwargs['model_states'])):
|
56 |
+
if mii >= 1 * len(kwargs['model_states']) / 3:
|
57 |
+
continue
|
58 |
+
text_outputs.append(gr.Chatbot(**chat_kwargs1))
|
59 |
+
with gr.Row():
|
60 |
+
for mii, (chat_kwargs1, model_state_lock) in enumerate(zip(chat_kwargs, kwargs['model_states'])):
|
61 |
+
if mii < 1 * len(kwargs['model_states']) / 3 or mii >= 2 * len(kwargs['model_states']) / 3:
|
62 |
+
continue
|
63 |
+
text_outputs.append(gr.Chatbot(**chat_kwargs1))
|
64 |
+
with gr.Row():
|
65 |
+
for mii, (chat_kwargs1, model_state_lock) in enumerate(zip(chat_kwargs, kwargs['model_states'])):
|
66 |
+
if mii < 2 * len(kwargs['model_states']) / 3:
|
67 |
+
continue
|
68 |
+
text_outputs.append(gr.Chatbot(**chat_kwargs1))
|
69 |
+
elif nrows >= 4:
|
70 |
+
with gr.Row():
|
71 |
+
for mii, (chat_kwargs1, model_state_lock) in enumerate(zip(chat_kwargs, kwargs['model_states'])):
|
72 |
+
if mii >= 1 * len(kwargs['model_states']) / 4:
|
73 |
+
continue
|
74 |
+
text_outputs.append(gr.Chatbot(**chat_kwargs1))
|
75 |
+
with gr.Row():
|
76 |
+
for mii, (chat_kwargs1, model_state_lock) in enumerate(zip(chat_kwargs, kwargs['model_states'])):
|
77 |
+
if mii < 1 * len(kwargs['model_states']) / 4 or mii >= 2 * len(kwargs['model_states']) / 4:
|
78 |
+
continue
|
79 |
+
text_outputs.append(gr.Chatbot(**chat_kwargs1))
|
80 |
+
with gr.Row():
|
81 |
+
for mii, (chat_kwargs1, model_state_lock) in enumerate(zip(chat_kwargs, kwargs['model_states'])):
|
82 |
+
if mii < 2 * len(kwargs['model_states']) / 4 or mii >= 3 * len(kwargs['model_states']) / 4:
|
83 |
+
continue
|
84 |
+
text_outputs.append(gr.Chatbot(**chat_kwargs1))
|
85 |
+
with gr.Row():
|
86 |
+
for mii, (chat_kwargs1, model_state_lock) in enumerate(zip(chat_kwargs, kwargs['model_states'])):
|
87 |
+
if mii < 3 * len(kwargs['model_states']) / 4:
|
88 |
+
continue
|
89 |
+
text_outputs.append(gr.Chatbot(**chat_kwargs1))
|
90 |
+
|
91 |
+
with gr.Row():
|
92 |
+
text_output = gr.Chatbot(label=output_label0, visible=not kwargs['model_lock'], height=kwargs['height'] or 400)
|
93 |
+
text_output2 = gr.Chatbot(label=output_label0_model2,
|
94 |
+
visible=False and not kwargs['model_lock'], height=kwargs['height'] or 400)
|
95 |
+
return text_output, text_output2, text_outputs
|
96 |
+
|
97 |
+
|
98 |
+
def make_prompt_form(kwargs):
|
99 |
+
if kwargs['input_lines'] > 1:
|
100 |
+
instruction_label = "Shift-Enter to Submit, Enter for more lines"
|
101 |
+
else:
|
102 |
+
instruction_label = "Enter to Submit, Shift-Enter for more lines"
|
103 |
+
|
104 |
+
with gr.Row():#elem_id='prompt-form-area'):
|
105 |
+
with gr.Column(scale=50):
|
106 |
+
instruction = gr.Textbox(
|
107 |
+
lines=kwargs['input_lines'],
|
108 |
+
label='Ask anything',
|
109 |
+
placeholder=instruction_label,
|
110 |
+
info=None,
|
111 |
+
elem_id='prompt-form',
|
112 |
+
container=True,
|
113 |
+
)
|
114 |
+
with gr.Row():
|
115 |
+
submit = gr.Button(value='Submit', variant='primary', scale=0, size='sm')
|
116 |
+
stop_btn = gr.Button(value="Stop", variant='secondary', scale=0, size='sm')
|
117 |
+
|
118 |
+
return instruction, submit, stop_btn
|
h2oai_pipeline.py
CHANGED
@@ -1,14 +1,17 @@
|
|
|
|
|
|
1 |
from transformers import TextGenerationPipeline
|
2 |
from transformers.pipelines.text_generation import ReturnType
|
3 |
|
4 |
from stopping import get_stopping
|
5 |
-
from prompter import Prompter
|
6 |
|
7 |
|
8 |
class H2OTextGenerationPipeline(TextGenerationPipeline):
|
9 |
def __init__(self, *args, debug=False, chat=False, stream_output=False,
|
10 |
-
sanitize_bot_response=
|
11 |
-
use_prompter=True, prompter=None,
|
|
|
12 |
max_input_tokens=2048 - 256, **kwargs):
|
13 |
"""
|
14 |
HF-like pipeline, but handle instruction prompting and stopping (for some models)
|
@@ -21,6 +24,7 @@ class H2OTextGenerationPipeline(TextGenerationPipeline):
|
|
21 |
:param prompter: prompter, can pass if have already
|
22 |
:param prompt_type: prompt_type, e.g. human_bot. See prompt_type to model mapping in from prompter.py.
|
23 |
If use_prompter, then will make prompter and use it.
|
|
|
24 |
:param max_input_tokens:
|
25 |
:param kwargs:
|
26 |
"""
|
@@ -28,12 +32,14 @@ class H2OTextGenerationPipeline(TextGenerationPipeline):
|
|
28 |
self.prompt_text = None
|
29 |
self.use_prompter = use_prompter
|
30 |
self.prompt_type = prompt_type
|
|
|
31 |
self.prompter = prompter
|
32 |
if self.use_prompter:
|
33 |
if self.prompter is not None:
|
34 |
assert self.prompter.prompt_type is not None
|
35 |
else:
|
36 |
-
self.prompter = Prompter(self.prompt_type, debug=debug, chat=chat,
|
|
|
37 |
self.human = self.prompter.humanstr
|
38 |
self.bot = self.prompter.botstr
|
39 |
self.can_stop = True
|
@@ -45,14 +51,75 @@ class H2OTextGenerationPipeline(TextGenerationPipeline):
|
|
45 |
self.sanitize_bot_response = sanitize_bot_response
|
46 |
self.max_input_tokens = max_input_tokens # not for generate, so ok that not kwargs
|
47 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
48 |
def preprocess(self, prompt_text, prefix="", handle_long_generation=None, **generate_kwargs):
|
|
|
|
|
49 |
data_point = dict(context='', instruction=prompt_text, input='')
|
50 |
if self.prompter is not None:
|
51 |
prompt_text = self.prompter.generate_prompt(data_point)
|
52 |
self.prompt_text = prompt_text
|
53 |
if handle_long_generation is None:
|
54 |
# forces truncation of inputs to avoid critical failure
|
55 |
-
handle_long_generation =
|
56 |
return super().preprocess(prompt_text, prefix=prefix, handle_long_generation=handle_long_generation,
|
57 |
**generate_kwargs)
|
58 |
|
@@ -65,7 +132,7 @@ class H2OTextGenerationPipeline(TextGenerationPipeline):
|
|
65 |
outputs = self.prompter.get_response(outputs, prompt=self.prompt_text,
|
66 |
sanitize_bot_response=self.sanitize_bot_response)
|
67 |
elif self.bot and self.human:
|
68 |
-
outputs = rec['generated_text'].split(self.bot)[1].
|
69 |
else:
|
70 |
outputs = rec['generated_text']
|
71 |
rec['generated_text'] = outputs
|
@@ -73,8 +140,10 @@ class H2OTextGenerationPipeline(TextGenerationPipeline):
|
|
73 |
|
74 |
def _forward(self, model_inputs, **generate_kwargs):
|
75 |
if self.can_stop:
|
76 |
-
stopping_criteria = get_stopping(self.prompt_type, self.
|
77 |
-
|
|
|
|
|
78 |
generate_kwargs['stopping_criteria'] = stopping_criteria
|
79 |
# return super()._forward(model_inputs, **generate_kwargs)
|
80 |
return self.__forward(model_inputs, **generate_kwargs)
|
|
|
1 |
+
import os
|
2 |
+
|
3 |
from transformers import TextGenerationPipeline
|
4 |
from transformers.pipelines.text_generation import ReturnType
|
5 |
|
6 |
from stopping import get_stopping
|
7 |
+
from prompter import Prompter, PromptType
|
8 |
|
9 |
|
10 |
class H2OTextGenerationPipeline(TextGenerationPipeline):
|
11 |
def __init__(self, *args, debug=False, chat=False, stream_output=False,
|
12 |
+
sanitize_bot_response=False,
|
13 |
+
use_prompter=True, prompter=None,
|
14 |
+
prompt_type=None, prompt_dict=None,
|
15 |
max_input_tokens=2048 - 256, **kwargs):
|
16 |
"""
|
17 |
HF-like pipeline, but handle instruction prompting and stopping (for some models)
|
|
|
24 |
:param prompter: prompter, can pass if have already
|
25 |
:param prompt_type: prompt_type, e.g. human_bot. See prompt_type to model mapping in from prompter.py.
|
26 |
If use_prompter, then will make prompter and use it.
|
27 |
+
:param prompt_dict: dict of get_prompt(, return_dict=True) for prompt_type=custom
|
28 |
:param max_input_tokens:
|
29 |
:param kwargs:
|
30 |
"""
|
|
|
32 |
self.prompt_text = None
|
33 |
self.use_prompter = use_prompter
|
34 |
self.prompt_type = prompt_type
|
35 |
+
self.prompt_dict = prompt_dict
|
36 |
self.prompter = prompter
|
37 |
if self.use_prompter:
|
38 |
if self.prompter is not None:
|
39 |
assert self.prompter.prompt_type is not None
|
40 |
else:
|
41 |
+
self.prompter = Prompter(self.prompt_type, self.prompt_dict, debug=debug, chat=chat,
|
42 |
+
stream_output=stream_output)
|
43 |
self.human = self.prompter.humanstr
|
44 |
self.bot = self.prompter.botstr
|
45 |
self.can_stop = True
|
|
|
51 |
self.sanitize_bot_response = sanitize_bot_response
|
52 |
self.max_input_tokens = max_input_tokens # not for generate, so ok that not kwargs
|
53 |
|
54 |
+
@staticmethod
|
55 |
+
     def limit_prompt(prompt_text, tokenizer, max_prompt_length=None):
+        verbose = bool(int(os.getenv('VERBOSE_PIPELINE', '0')))
+
+        if hasattr(tokenizer, 'model_max_length'):
+            # model_max_length only defined for generate.py, not raw use of h2oai_pipeline.py
+            model_max_length = tokenizer.model_max_length
+            if max_prompt_length is not None:
+                model_max_length = min(model_max_length, max_prompt_length)
+            # cut at some upper likely limit to avoid excessive tokenization etc
+            # upper bound of 10 chars/token, e.g. special chars sometimes are long
+            if len(prompt_text) > model_max_length * 10:
+                len0 = len(prompt_text)
+                prompt_text = prompt_text[-model_max_length * 10:]
+                if verbose:
+                    print("Cut of input: %s -> %s" % (len0, len(prompt_text)), flush=True)
+        else:
+            # unknown
+            model_max_length = None
+
+        num_prompt_tokens = None
+        if model_max_length is not None:
+            # can't wait for "hole" if not plain prompt_type, since would lose prefix like <human>:
+            # For https://github.com/h2oai/h2ogpt/issues/192
+            for trial in range(0, 3):
+                prompt_tokens = tokenizer(prompt_text)['input_ids']
+                num_prompt_tokens = len(prompt_tokens)
+                if num_prompt_tokens > model_max_length:
+                    # conservative by using int()
+                    chars_per_token = int(len(prompt_text) / num_prompt_tokens)
+                    # keep tail, where question is if using langchain
+                    prompt_text = prompt_text[-model_max_length * chars_per_token:]
+                    if verbose:
+                        print("reducing %s tokens, assuming average of %s chars/token for %s characters" % (
+                            num_prompt_tokens, chars_per_token, len(prompt_text)), flush=True)
+                else:
+                    if verbose:
+                        print("using %s tokens with %s chars" % (num_prompt_tokens, len(prompt_text)), flush=True)
+                    break
+
+        # Why Below False: don't limit max_new_tokens more, just rely upon stopping to reach limit of model
+        if False:
+            # if input prompt is some number of tokens, despite user request, can't have max_new_tokens more
+            assert num_prompt_tokens is not None
+            if self.prompt_type not in [PromptType.plain.name, PromptType.plain.value]:
+                # then give room for prompt
+                fudge = 20
+            else:
+                fudge = 0
+            max_new_tokens = max(0, min(generate_kwargs['max_new_tokens'],
+                                        model_max_length - (num_prompt_tokens + fudge)))
+            if max_new_tokens < generate_kwargs['max_new_tokens']:
+                if verbose:
+                    print("Reduced max_new_tokens from %s -> %s" % (
+                        generate_kwargs['max_new_tokens'], max_new_tokens))
+                generate_kwargs['max_new_tokens'] = max_new_tokens
+        return prompt_text, num_prompt_tokens
+
     def preprocess(self, prompt_text, prefix="", handle_long_generation=None, **generate_kwargs):
+        prompt_text, num_prompt_tokens = H2OTextGenerationPipeline.limit_prompt(prompt_text, self.tokenizer)
+
         data_point = dict(context='', instruction=prompt_text, input='')
         if self.prompter is not None:
             prompt_text = self.prompter.generate_prompt(data_point)
         self.prompt_text = prompt_text
         if handle_long_generation is None:
             # forces truncation of inputs to avoid critical failure
+            handle_long_generation = None  # disable with new approaches
         return super().preprocess(prompt_text, prefix=prefix, handle_long_generation=handle_long_generation,
                                   **generate_kwargs)

                 outputs = self.prompter.get_response(outputs, prompt=self.prompt_text,
                                                      sanitize_bot_response=self.sanitize_bot_response)
             elif self.bot and self.human:
+                outputs = rec['generated_text'].split(self.bot)[1].split(self.human)[0]
             else:
                 outputs = rec['generated_text']
             rec['generated_text'] = outputs

     def _forward(self, model_inputs, **generate_kwargs):
         if self.can_stop:
+            stopping_criteria = get_stopping(self.prompt_type, self.prompt_dict,
+                                             self.tokenizer, self.device,
+                                             human=self.human, bot=self.bot,
+                                             model_max_length=self.tokenizer.model_max_length)
             generate_kwargs['stopping_criteria'] = stopping_criteria
         # return super()._forward(model_inputs, **generate_kwargs)
         return self.__forward(model_inputs, **generate_kwargs)
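The truncation above only relies on the tokenizer's model_max_length, optionally capped by max_prompt_length. A minimal sketch of exercising the helper outside the pipeline, assuming h2oai_pipeline.py is importable; the checkpoint name is only an example, not prescribed by this diff:

    # Sketch only: call limit_prompt directly with a Hugging Face tokenizer.
    from transformers import AutoTokenizer
    from h2oai_pipeline import H2OTextGenerationPipeline

    tokenizer = AutoTokenizer.from_pretrained("h2oai/h2ogpt-oig-oasst1-512-6.9b")  # example model
    long_prompt = "word " * 50000  # far beyond any context window
    short_prompt, num_tokens = H2OTextGenerationPipeline.limit_prompt(
        long_prompt, tokenizer, max_prompt_length=2048)
    # the tail of the prompt is kept, and the token count actually used is reported
    print(len(long_prompt), len(short_prompt), num_tokens)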
loaders.py
CHANGED
@@ -1,6 +1,8 @@
-def get_loaders(
+def get_loaders(model_name, reward_type, llama_type=None):
     # NOTE: Some models need specific new prompt_type
     # E.g. t5_xxl_true_nli_mixture has input format: "premise: PREMISE_TEXT hypothesis: HYPOTHESIS_TEXT".)
+    if llama_type is None:
+        llama_type = "llama" in model_name.lower()
     if llama_type:
         from transformers import LlamaForCausalLM, LlamaTokenizer
         model_loader = LlamaForCausalLM
@@ -39,7 +41,8 @@ def get_tokenizer(tokenizer_loader, tokenizer_base_model, local_files_only, resume_download, ...):
     tokenizer = tokenizer_loader.from_pretrained(tokenizer_base_model,
                                                  local_files_only=local_files_only,
                                                  resume_download=resume_download,
-                                                 use_auth_token=use_auth_token
+                                                 use_auth_token=use_auth_token,
+                                                 padding_side='left')

     tokenizer.pad_token_id = 0  # different from the eos token
     # when generating, we will use the logits of right-most token to predict the next token
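The behavioral change here is that llama_type now defaults to a name-based guess instead of being a required argument. A standalone illustration of that default (plain Python, not taken from the repo):

    def infer_llama_type(model_name, llama_type=None):
        # mirrors the new default in get_loaders(): infer from the model name when unset
        if llama_type is None:
            llama_type = "llama" in model_name.lower()
        return llama_type

    assert infer_llama_type("h2oai/h2ogpt-gm-oasst1-en-2048-open-llama-7b") is True
    assert infer_llama_type("EleutherAI/gpt-j-6B") is False

Passing llama_type explicitly still overrides the guess, so unusually named LLaMa checkpoints keep working.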
prompter.py
CHANGED
@@ -1,30 +1,10 @@
+import os
+import ast
 import time
+from enums import PromptType  # also supports imports from this file from other files

 non_hf_types = ['gpt4all_llama', 'llama', 'gptj']

-
-class PromptType(Enum):
-    plain = 0
-    instruct = 1
-    quality = 2
-    human_bot = 3
-    dai_faq = 4
-    summarize = 5
-    simple_instruct = 6
-    instruct_vicuna = 7
-    instruct_with_end = 8
-    human_bot_orig = 9
-    prompt_answer = 10
-    open_assistant = 11
-    wizard_lm = 12
-    wizard_mega = 13
-    instruct_vicuna2 = 14
-    instruct_vicuna3 = 15
-    wizard2 = 16
-    wizard3 = 17
-
-
 prompt_type_to_model_name = {
     'plain': [
         'EleutherAI/gpt-j-6B',
@@ -45,17 +25,29 @@
         'mosaicml/mpt-7b-storywriter',
         'mosaicml/mpt-7b-instruct',  # internal code handles instruct
         'mosaicml/mpt-7b-chat',  # NC, internal code handles instruct
-        'llama',  # plain, or need to choose prompt_type for given TheBloke model
-        'gpt4all_llama',  # internally handles prompting
+        'mosaicml/mpt-30b-instruct',  # internal code handles instruct
     ],
+    'gptj': ['gptj', 'gpt4all_llama'],
     'prompt_answer': [
         'h2oai/h2ogpt-gm-oasst1-en-1024-20b',
         'h2oai/h2ogpt-gm-oasst1-en-1024-12b',
        'h2oai/h2ogpt-gm-oasst1-multilang-1024-20b',
+        'h2oai/h2ogpt-gm-oasst1-multilang-2048-falcon-7b',
+        'h2oai/h2ogpt-gm-oasst1-multilang-2048-falcon-7b-v2',
+        'h2oai/h2ogpt-gm-oasst1-en-2048-falcon-7b-v3',
+        'h2oai/h2ogpt-gm-oasst1-en-2048-falcon-7b',
+        'h2oai/h2ogpt-gm-oasst1-en-2048-falcon-7b-v2',
+        'h2oai/h2ogpt-gm-oasst1-en-2048-falcon-40b-v1',
+        'h2oai/h2ogpt-gm-oasst1-en-2048-falcon-40b-v2',
+        'h2oai/h2ogpt-gm-oasst1-en-xgen-7b-8k',
+        'h2oai/h2ogpt-gm-oasst1-multilang-xgen-7b-8k',
+    ],
+    'prompt_answer_openllama': [
         'h2oai/h2ogpt-gm-oasst1-en-2048-open-llama-7b-preview-300bt',
         'h2oai/h2ogpt-gm-oasst1-en-2048-open-llama-7b-preview-300bt-v2',
         'h2oai/h2ogpt-gm-oasst1-en-2048-open-llama-7b-preview-700bt',
+        'h2oai/h2ogpt-gm-oasst1-en-2048-open-llama-7b',
+        'h2oai/h2ogpt-gm-oasst1-en-2048-open-llama-13b',
     ],
     'instruct': [],
     'instruct_with_end': ['databricks/dolly-v2-12b'],
@@ -68,6 +60,7 @@
     'h2oai/h2ogpt-oig-oasst1-256-6.9b',  # legacy
     'h2oai/h2ogpt-oig-oasst1-512-6.9b',  # legacy
     'h2oai/h2ogpt-research-oasst1-512-30b',
+    'h2oai/h2ogpt-research-oasst1-llama-65b',
     'h2oai/h2ogpt-oasst1-falcon-40b',
     'h2oai/h2ogpt-oig-oasst1-falcon-40b',
 ],
@@ -79,7 +72,17 @@
     "open_assistant": ['OpenAssistant/oasst-sft-7-llama-30b-xor', 'oasst-sft-7-llama-30b'],
     "wizard_lm": ['ehartford/WizardLM-7B-Uncensored', 'ehartford/WizardLM-13B-Uncensored'],
     "wizard_mega": ['openaccess-ai-collective/wizard-mega-13b'],
+    "instruct_simple": ['JosephusCheung/Guanaco'],
+    "wizard_vicuna": ['ehartford/Wizard-Vicuna-13B-Uncensored'],
+    "wizard2": ['llama', 'mosaicml/mpt-30b-instruct'],
+    "vicuna11": ['lmsys/vicuna-33b-v1.3'],
+    # could be plain, but default is correct prompt_type for default TheBloke model ggml-wizardLM-7B.q4_2.bin
 }
+if os.getenv('OPENAI_API_KEY'):
+    prompt_type_to_model_name.update({
+        "openai": ["text-davinci-003", "text-curie-001", "text-babbage-001", "text-ada-001"],
+        "openai_chat": ["gpt-3.5-turbo", "gpt-3.5-turbo-16k"],
+    })

 inv_prompt_type_to_model_name = {v.strip(): k for k, l in prompt_type_to_model_name.items() for v in l}
 inv_prompt_type_to_model_lower = {v.strip().lower(): k for k, l in prompt_type_to_model_name.items() for v in l}
@@ -93,20 +96,53 @@
     prompt_types.extend([p.name, p.value, str(p.value)])


-def get_prompt(prompt_type, chat, context, reduced):
+def get_prompt(prompt_type, prompt_dict, chat, context, reduced, making_context, return_dict=False):
+    prompt_dict_error = ''
+    generates_leading_space = False
+
+    if prompt_type == PromptType.custom.name and not isinstance(prompt_dict, dict):
+        try:
+            prompt_dict = ast.literal_eval(prompt_dict)
+        except BaseException as e:
+            prompt_dict_error = str(e)
+    if prompt_dict_error:
+        promptA = None
+        promptB = None
+        PreInstruct = None
+        PreInput = ''
+        PreResponse = ''
+        terminate_response = None
         chat_sep = ''
+        chat_turn_sep = ''
         humanstr = ''
         botstr = ''
+        generates_leading_space = False
+    elif prompt_type in [PromptType.custom.value, str(PromptType.custom.value),
+                         PromptType.custom.name]:
+        promptA = prompt_dict.get('promptA', '')
+        promptB = prompt_dict('promptB', '')
+        PreInstruct = prompt_dict.get('PreInstruct', '')
+        PreInput = prompt_dict.get('PreInput', '')
+        PreResponse = prompt_dict.get('PreResponse', '')
+        terminate_response = prompt_dict.get('terminate_response', None)
+        chat_sep = prompt_dict.get('chat_sep', '\n')
+        chat_turn_sep = prompt_dict.get('chat_turn_sep', '\n')
+        humanstr = prompt_dict.get('humanstr', '')
+        botstr = prompt_dict.get('botstr', '')
+    elif prompt_type in [PromptType.plain.value, str(PromptType.plain.value),
+                         PromptType.plain.name]:
+        promptA = promptB = PreInstruct = PreInput = PreResponse = None
+        terminate_response = []
+        chat_turn_sep = chat_sep = ''
+        # plain should have None for human/bot, so nothing truncated out, not '' that would truncate after first token
+        humanstr = None
+        botstr = None
     elif prompt_type == 'simple_instruct':
         promptA = promptB = PreInstruct = PreInput = PreResponse = None
         terminate_response = []
-        chat_sep = '\n'
+        chat_turn_sep = chat_sep = '\n'
+        humanstr = None
+        botstr = None
     elif prompt_type in [PromptType.instruct.value, str(PromptType.instruct.value),
                          PromptType.instruct.name] + [PromptType.instruct_with_end.value,
                                                       str(PromptType.instruct_with_end.value),
@@ -132,7 +168,7 @@
             terminate_response = ['### End']
         else:
             terminate_response = None
-        chat_sep = '\n'
+        chat_turn_sep = chat_sep = '\n'
         humanstr = PreInstruct
         botstr = PreResponse
     elif prompt_type in [PromptType.quality.value, str(PromptType.quality.value),
@@ -154,7 +190,7 @@
 ### Response:
 """
         terminate_response = None
-        chat_sep = '\n'
+        chat_turn_sep = chat_sep = '\n'
         humanstr = PreInstruct  # first thing human says
         botstr = PreResponse  # first thing bot says
     elif prompt_type in [PromptType.human_bot.value, str(PromptType.human_bot.value),
@@ -176,14 +212,14 @@

 """
         preprompt = PRE_PROMPT.format(cur_date, cur_time)
+        start = ''
+        promptB = promptA = '%s%s' % (preprompt, start)

+        PreInstruct = human + ' '

         PreInput = None

+        if making_context:
             # when making context, want it to appear as-if LLM generated, which starts with space after :
             PreResponse = bot + ' '
         else:
@@ -191,10 +227,11 @@
             # if add space here, non-unique tokenization will often make LLM produce wrong output
             PreResponse = bot

-        chat_sep = '\n'
+        terminate_response = ['\n' + human, '\n' + bot, human, bot, PreResponse]
+        chat_turn_sep = chat_sep = '\n'
         humanstr = human  # tag before human talks
         botstr = bot  # tag before bot talks
+        generates_leading_space = True
     elif prompt_type in [PromptType.dai_faq.value, str(PromptType.dai_faq.value),
                          PromptType.dai_faq.name]:
         promptA = ''
@@ -210,7 +247,7 @@
 ### Driverless AI documentation answer:
 """
         terminate_response = ['\n\n']
-        chat_sep = terminate_response
+        chat_turn_sep = chat_sep = terminate_response
         humanstr = PreInstruct
         botstr = PreResponse
     elif prompt_type in [PromptType.summarize.value, str(PromptType.summarize.value),
@@ -219,7 +256,7 @@
         PreInstruct = '## Main Text\n\n'
         PreResponse = '\n\n## Summary\n\n'
         terminate_response = None
-        chat_sep = '\n'
+        chat_turn_sep = chat_sep = '\n'
         humanstr = PreInstruct
         botstr = PreResponse
     elif prompt_type in [PromptType.instruct_vicuna.value, str(PromptType.instruct_vicuna.value),
@@ -239,7 +276,7 @@
 """
         terminate_response = [
             '### Human:']  # but only allow terminate after prompt is found correctly, else can't terminate
-        chat_sep = '\n'
+        chat_turn_sep = chat_sep = '\n'
         humanstr = PreInstruct
         botstr = PreResponse
     elif prompt_type in [PromptType.prompt_answer.value, str(PromptType.prompt_answer.value),
@@ -247,33 +284,50 @@
         preprompt = ''
         prompt_tokens = "<|prompt|>"
         answer_tokens = "<|answer|>"
+        start = ''
         promptB = promptA = '%s%s' % (preprompt, start)
+        PreInstruct = prompt_tokens
         PreInput = None
         PreResponse = answer_tokens
         eos = '<|endoftext|>'  # neox eos
-        terminate_response = [start, PreResponse, eos]
-        chat_sep = eos
         humanstr = prompt_tokens
         botstr = answer_tokens
+        terminate_response = [humanstr, PreResponse, eos]
+        chat_sep = ''
+        chat_turn_sep = eos
+    elif prompt_type in [PromptType.prompt_answer_openllama.value, str(PromptType.prompt_answer_openllama.value),
+                         PromptType.prompt_answer_openllama.name]:
+        preprompt = ''
+        prompt_tokens = "<|prompt|>"
+        answer_tokens = "<|answer|>"
+        start = ''
+        promptB = promptA = '%s%s' % (preprompt, start)
+        PreInstruct = prompt_tokens
+        PreInput = None
+        PreResponse = answer_tokens
+        eos = '</s>'  # llama eos
+        humanstr = prompt_tokens
+        botstr = answer_tokens
+        terminate_response = [humanstr, PreResponse, eos]
+        chat_sep = ''
+        chat_turn_sep = eos
     elif prompt_type in [PromptType.open_assistant.value, str(PromptType.open_assistant.value),
                          PromptType.open_assistant.name]:
         # From added_tokens.json
         preprompt = ''
         prompt_tokens = "<|prompter|>"
         answer_tokens = "<|assistant|>"
+        start = ''
         promptB = promptA = '%s%s' % (preprompt, start)
+        PreInstruct = prompt_tokens
         PreInput = None
         PreResponse = answer_tokens
         pend = "<|prefix_end|>"
         eos = "</s>"
-        terminate_response = [start, PreResponse, pend, eos]
-        chat_sep = eos
         humanstr = prompt_tokens
         botstr = answer_tokens
+        terminate_response = [humanstr, PreResponse, pend, eos]
+        chat_turn_sep = chat_sep = eos
     elif prompt_type in [PromptType.wizard_lm.value, str(PromptType.wizard_lm.value),
                          PromptType.wizard_lm.name]:
         # https://github.com/ehartford/WizardLM/blob/main/src/train_freeform.py
@@ -285,7 +339,7 @@
         PreResponse = "\n\n### Response\n"
         eos = "</s>"
         terminate_response = [PreResponse, eos]
-        chat_sep = eos
+        chat_turn_sep = chat_sep = eos
         humanstr = promptA
         botstr = PreResponse
     elif prompt_type in [PromptType.wizard_mega.value, str(PromptType.wizard_mega.value),
@@ -301,13 +355,12 @@
 ### Assistant:
 """
         terminate_response = [PreResponse]
-        chat_sep = '\n'
+        chat_turn_sep = chat_sep = '\n'
         humanstr = PreInstruct
         botstr = PreResponse
     elif prompt_type in [PromptType.instruct_vicuna2.value, str(PromptType.instruct_vicuna2.value),
                          PromptType.instruct_vicuna2.name]:
-        promptA = promptB = "" if not (
-            chat and reduced) else ''
+        promptA = promptB = "" if not (chat and reduced) else ''

         PreInstruct = """
 HUMAN:
@@ -320,13 +373,12 @@
 """
         terminate_response = [
             'HUMAN:']  # but only allow terminate after prompt is found correctly, else can't terminate
-        chat_sep = '\n'
+        chat_turn_sep = chat_sep = '\n'
         humanstr = PreInstruct
         botstr = PreResponse
     elif prompt_type in [PromptType.instruct_vicuna3.value, str(PromptType.instruct_vicuna3.value),
                          PromptType.instruct_vicuna3.name]:
-        promptA = promptB = "" if not (
-            chat and reduced) else ''
+        promptA = promptB = "" if not (chat and reduced) else ''

         PreInstruct = """
 ### User:
@@ -339,13 +391,14 @@
 """
         terminate_response = [
             '### User:']  # but only allow terminate after prompt is found correctly, else can't terminate
-        chat_sep = '\n'
+        chat_turn_sep = chat_sep = '\n'
         humanstr = PreInstruct
         botstr = PreResponse
     elif prompt_type in [PromptType.wizard2.value, str(PromptType.wizard2.value),
                          PromptType.wizard2.name]:
         # https://huggingface.co/TheBloke/WizardLM-7B-uncensored-GGML
-        preprompt = """Below is an instruction that describes a task. Write a response that appropriately completes the request."""
+        preprompt = """Below is an instruction that describes a task. Write a response that appropriately completes the request.""" if not (
+            chat and reduced) else ''
         start = ''
         promptB = promptA = '%s%s' % (preprompt, start)
         PreInstruct = """
@@ -356,30 +409,136 @@
 ### Response:
 """
         terminate_response = [PreResponse]
-        chat_sep = '\n'
+        chat_turn_sep = chat_sep = '\n'
         humanstr = PreInstruct
         botstr = PreResponse
     elif prompt_type in [PromptType.wizard3.value, str(PromptType.wizard3.value),
                          PromptType.wizard3.name]:
         # https://huggingface.co/TheBloke/wizardLM-13B-1.0-GGML
-        preprompt = """A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions."""
+        preprompt = """A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions.""" if not (
+            chat and reduced) else ''
         start = ''
         promptB = promptA = '%s%s' % (preprompt, start)
         PreInstruct = """USER: """
         PreInput = None
         PreResponse = """ASSISTANT: """
         terminate_response = [PreResponse]
-        chat_sep = '\n'
+        chat_turn_sep = chat_sep = '\n'
         humanstr = PreInstruct
         botstr = PreResponse
+    elif prompt_type in [PromptType.wizard_vicuna.value, str(PromptType.wizard_vicuna.value),
+                         PromptType.wizard_vicuna.name]:
+        preprompt = ''
+        start = ''
+        promptB = promptA = '%s%s' % (preprompt, start)
+        PreInstruct = """USER: """
+        PreInput = None
+        PreResponse = """ASSISTANT: """
+        terminate_response = [PreResponse]
+        chat_turn_sep = chat_sep = '\n'
+        humanstr = PreInstruct
+        botstr = PreResponse
+
+    elif prompt_type in [PromptType.instruct_simple.value, str(PromptType.instruct_simple.value),
+                         PromptType.instruct_simple.name]:
+        promptB = promptA = '' if not (chat and reduced) else ''
+
+        PreInstruct = """
+### Instruction:
+"""
+
+        PreInput = """
+### Input:
+"""
+
+        PreResponse = """
+### Response:
+"""
+        terminate_response = None
+        chat_turn_sep = chat_sep = '\n'
+        humanstr = PreInstruct
+        botstr = PreResponse
+    elif prompt_type in [PromptType.openai.value, str(PromptType.openai.value),
+                         PromptType.openai.name]:
+        preprompt = """The following is a conversation with an AI assistant. The assistant is helpful, creative, clever, and very friendly.""" if not (
+            chat and reduced) else ''
+        start = ''
+        promptB = promptA = '%s%s' % (preprompt, start)
+        PreInstruct = "\nHuman: "
+        PreInput = None
+        PreResponse = "\nAI:"
+        terminate_response = [PreResponse] + [" Human:", " AI:"]
+        chat_turn_sep = chat_sep = '\n'
+        humanstr = PreInstruct
+        botstr = PreResponse
+    elif prompt_type in [PromptType.gptj.value, str(PromptType.gptj.value),
+                         PromptType.gptj.name]:
+        preprompt = "### Instruction:\n The prompt below is a question to answer, a task to complete, or a conversation to respond to; decide which and write an appropriate response." if not (
+            chat and reduced) else ''
+        start = ''
+        promptB = promptA = '%s%s' % (preprompt, start)
+        PreInstruct = "\n### Prompt: "
+        PreInput = None
+        PreResponse = "\n### Response: "
+        terminate_response = [PreResponse] + ["Prompt:", "Response:"]
+        chat_turn_sep = chat_sep = '\n'
+        humanstr = PreInstruct
+        botstr = PreResponse
+    elif prompt_type in [PromptType.openai_chat.value, str(PromptType.openai_chat.value),
+                         PromptType.openai_chat.name]:
+        # prompting and termination all handled by endpoint
+        preprompt = """"""
+        start = ''
+        promptB = promptA = '%s%s' % (preprompt, start)
+        PreInstruct = ""
+        PreInput = None
+        PreResponse = ""
+        terminate_response = []
+        chat_turn_sep = chat_sep = '\n'
+        humanstr = None
+        botstr = None
+    elif prompt_type in [PromptType.vicuna11.value, str(PromptType.vicuna11.value),
+                         PromptType.vicuna11.name]:
+        preprompt = """A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions. """ if not (
+            chat and reduced) else ''
+        start = ''
+        promptB = promptA = '%s%s' % (preprompt, start)
+        eos = '</s>'
+        PreInstruct = """USER: """
+        PreInput = None
+        PreResponse = """ASSISTANT:"""
+        terminate_response = [PreResponse]
+        chat_sep = ' '
+        chat_turn_sep = eos
+        humanstr = PreInstruct
+        botstr = PreResponse
+
+        if making_context:
+            # when making context, want it to appear as-if LLM generated, which starts with space after :
+            PreResponse = PreResponse + ' '
+        else:
+            # normally LLM adds space after this, because was how trained.
+            # if add space here, non-unique tokenization will often make LLM produce wrong output
+            PreResponse = PreResponse
     else:
         raise RuntimeError("No such prompt_type=%s" % prompt_type)

+    if isinstance(terminate_response, (tuple, list)):
+        assert '' not in terminate_response, "Bad terminate_response"
+
+    ret_dict = dict(promptA=promptA, promptB=promptB, PreInstruct=PreInstruct, PreInput=PreInput,
+                    PreResponse=PreResponse, terminate_response=terminate_response, chat_sep=chat_sep,
+                    chat_turn_sep=chat_turn_sep,
+                    humanstr=humanstr, botstr=botstr,
+                    generates_leading_space=generates_leading_space)

+    if return_dict:
+        return ret_dict, prompt_dict_error
+    else:
+        return tuple(list(ret_dict.values()))

+
+def generate_prompt(data_point, prompt_type, prompt_dict, chat, reduced, making_context):
     context = data_point.get('context')
     if context is None:
         context = ''
@@ -387,11 +546,15 @@
     input = data_point.get('input')
     output = data_point.get('output')
     prompt_type = data_point.get('prompt_type', prompt_type)
+    prompt_dict = data_point.get('prompt_dict', prompt_dict)
     assert prompt_type in prompt_types, "Bad prompt type: %s" % prompt_type
     promptA, promptB, PreInstruct, PreInput, PreResponse, \
+        terminate_response, chat_sep, chat_turn_sep, humanstr, botstr, \
+        generates_leading_space = get_prompt(prompt_type, prompt_dict, chat,
+                                             context, reduced, making_context)

+    # could avoid if reduce=True, but too complex for parent functions to handle
+    prompt = context

     if input and promptA:
         prompt += f"""{promptA}"""
@@ -400,37 +563,37 @@

     if instruction and PreInstruct is not None and input and PreInput is not None:
         prompt += f"""{PreInstruct}{instruction}{PreInput}{input}"""
+        prompt = inject_chatsep(prompt_type, prompt, chat_sep=chat_sep)
     elif instruction and input and PreInstruct is None and PreInput is not None:
         prompt += f"""{PreInput}{instruction}
 {input}"""
+        prompt = inject_chatsep(prompt_type, prompt, chat_sep=chat_sep)
     elif input and instruction and PreInput is None and PreInstruct is not None:
         prompt += f"""{PreInstruct}{instruction}
 {input}"""
+        prompt = inject_chatsep(prompt_type, prompt, chat_sep=chat_sep)
     elif instruction and PreInstruct is not None:
         prompt += f"""{PreInstruct}{instruction}"""
+        prompt = inject_chatsep(prompt_type, prompt, chat_sep=chat_sep)
     elif input and PreInput is not None:
         prompt += f"""{PreInput}{input}"""
+        prompt = inject_chatsep(prompt_type, prompt, chat_sep=chat_sep)
     elif input and instruction and PreInput is not None:
         prompt += f"""{PreInput}{instruction}{input}"""
+        prompt = inject_chatsep(prompt_type, prompt, chat_sep=chat_sep)
     elif input and instruction and PreInstruct is not None:
         prompt += f"""{PreInstruct}{instruction}{input}"""
+        prompt = inject_chatsep(prompt_type, prompt, chat_sep=chat_sep)
     elif input and instruction:
         # i.e. for simple_instruct
         prompt += f"""{instruction}: {input}"""
+        prompt = inject_chatsep(prompt_type, prompt, chat_sep=chat_sep)
     elif input:
         prompt += f"""{input}"""
+        prompt = inject_chatsep(prompt_type, prompt, chat_sep=chat_sep)
     elif instruction:
         prompt += f"""{instruction}"""
+        prompt = inject_chatsep(prompt_type, prompt, chat_sep=chat_sep)

     if PreResponse is not None:
         prompt += f"""{PreResponse}"""
@@ -441,23 +604,21 @@
     if output:
         prompt += f"""{output}"""

-    return prompt, pre_response, terminate_response, chat_sep
+    return prompt, pre_response, terminate_response, chat_sep, chat_turn_sep


+def inject_chatsep(prompt_type, prompt, chat_sep=None):
+    if chat_sep:
         # only add new line if structured prompt, while 'plain' is just generation of next tokens from input
+        prompt += chat_sep
     return prompt


 class Prompter(object):
-    def __init__(self, prompt_type, debug=False, chat=False, stream_output=False, repeat_penalty=True,
+    def __init__(self, prompt_type, prompt_dict, debug=False, chat=False, stream_output=False, repeat_penalty=True,
                  allowed_repeat_line_length=10):
         self.prompt_type = prompt_type
-        _, self.pre_response, self.terminate_response, self.chat_sep = \
-            generate_prompt(data_point, prompt_type, chat, False)
+        self.prompt_dict = prompt_dict
         self.debug = debug
         self.chat = chat
         self.stream_output = stream_output
@@ -466,23 +627,41 @@
         self.prompt = None
         context = ""  # not for chat context
         reduced = False  # not for chat context
+        making_context = False  # not for chat context
         self.promptA, self.promptB, self.PreInstruct, self.PreInput, self.PreResponse, \
+            self.terminate_response, self.chat_sep, self.chat_turn_sep, self.humanstr, self.botstr, \
+            self.generates_leading_space = \
+            get_prompt(self.prompt_type, self.prompt_dict, chat, context, reduced, making_context)
+        self.pre_response = self.PreResponse
+
+    def generate_prompt(self, data_point, reduced=None):
+        """
+        data_point['context'] is assumed to be like a system prompt or pre-conversation, not inserted after user prompt
+        :param data_point:
+        :param reduced:
+        :return:
+        """
+        reduced = data_point.get('context') not in ['', None] if reduced is None else reduced
+        making_context = False  # whether really making final prompt or just generating context
+        prompt, _, _, _, _ = generate_prompt(data_point, self.prompt_type, self.prompt_dict, self.chat, reduced,
+                                             making_context)
         if self.debug:
+            print("prompt: %s" % prompt, flush=True)
+        # if have context, should have always reduced and only preappend promptA/B here
+        if data_point.get('context'):
+            if data_point.get('input') and self.promptA:
+                prompt = self.promptA + prompt
+            elif self.promptB:
+                prompt = self.promptB + prompt
+
         self.prompt = prompt
         return prompt

+    def get_response(self, outputs, prompt=None, sanitize_bot_response=False):
         if isinstance(outputs, str):
             outputs = [outputs]
         if self.debug:
+            print("output:\n%s" % '\n\n'.join(outputs), flush=True)
         if prompt is not None:
             self.prompt = prompt
@@ -493,7 +672,8 @@
             if sanitize_bot_response:
                 from better_profanity import profanity
                 response = profanity.censor(response)
+            if self.generates_leading_space and isinstance(response, str) and len(response) > 0 and response[0] == ' ':
+                response = response[1:]
             return response

         def clean_repeats(response):
@@ -515,12 +695,12 @@
             # then use most basic parsing like pipeline
             if self.botstr in output:
                 if self.humanstr:
+                    output = clean_response(output.split(self.botstr)[1].split(self.humanstr)[0])
                 else:
                     # i.e. use after bot but only up to next bot
+                    output = clean_response(output.split(self.botstr)[1].split(self.botstr)[0])
             else:
+                # output = clean_response(output)
                 # assume just not printed yet
                 output = ""
         else:
@@ -547,9 +727,9 @@
             allow_terminate = True
             output = output[len(prompt):]
             # clean after subtract prompt out, so correct removal of pre_response
+            output = clean_response(output)
             if self.repeat_penalty:
+                output = clean_repeats(output)
             if self.terminate_response and allow_terminate:
                 finds = []
                 for term in self.terminate_response:
@@ -557,11 +737,9 @@
                 finds = [x for x in finds if x >= 0]
                 if len(finds) > 0:
                     termi = finds[0]
+                    output = output[:termi]
                 else:
+                    output = output
-            else:
-                output = output.strip()
             if multi_output:
                 # prefix with output counter
                 output = "\n=========== Output %d\n\n" % (1 + oi) + output
@@ -572,5 +750,5 @@
         # join all outputs, only one extra new line between outputs
         output = '\n'.join(outputs)
         if self.debug:
+            print("outputclean:\n%s" % '\n\n'.join(outputs), flush=True)
         return output
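Prompter now takes prompt_dict as its second positional argument (None is fine for built-in prompt types) and exposes chat_turn_sep and generates_leading_space. A rough usage sketch, assuming the human_bot prompt type; the exact cleanup depends on parts of get_response not visible in this hunk:

    from prompter import Prompter

    prompter = Prompter('human_bot', None, chat=True, stream_output=False)
    data_point = dict(context='', instruction='Why is the sky blue?', input='')
    prompt = prompter.generate_prompt(data_point)
    # pretend the model generated this continuation, then strip prompt and stop tokens
    raw_output = prompt + " Because of Rayleigh scattering.\n<human>:"
    print(prompter.get_response(raw_output, prompt=prompt))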
requirements.txt
CHANGED
@@ -1,50 +1,50 @@
 # for generate (gradio server) and finetune
+datasets==2.13.0
+sentencepiece==0.1.99
+gradio==3.35.2
+huggingface_hub==0.15.1
 appdirs==1.4.4
 fire==0.5.0
+docutils==0.20.1
 torch==2.0.1
 evaluate==0.4.0
 rouge_score==0.1.2
 sacrebleu==2.3.1
 scikit-learn==1.2.2
 alt-profanity-check==1.2.2
+better-profanity==0.7.0
+numpy==1.24.3
+pandas==2.0.2
 matplotlib==3.7.1
 loralib==0.1.1
 bitsandbytes==0.39.0
+accelerate==0.20.3
+git+https://github.com/huggingface/peft.git@0b62b4378b4ce9367932c73540349da9a41bdea8
+transformers==4.30.2
 tokenizers==0.13.3
 APScheduler==3.10.1

 # optional for generate
 pynvml==11.5.0
+psutil==5.9.5
 boto3==1.26.101
 botocore==1.29.101

 # optional for finetune
+tensorboard==2.13.0
+neptune==1.2.0

 # for gradio client
+gradio_client==0.2.7
 beautifulsoup4==4.12.2
+markdown==3.4.3

 # data and testing
 pytest==7.2.2
 pytest-xdist==3.2.1
 nltk==3.8.1
 textstat==0.7.3
-pandoc==2.3
+# pandoc==2.3
 #pypandoc==1.11
 pypandoc_binary==1.11
 openpyxl==3.1.2
@@ -53,17 +53,66 @@

 # falcon
 einops==0.6.1
+instructorembedding==1.0.1
+
+# for gpt4all .env file, but avoid worrying about imports
+python-dotenv==1.0.0
+
+text-generation==0.6.0
+# for tokenization when don't have HF tokenizer
+tiktoken==0.4.0
+# optional: for OpenAI endpoint or embeddings (requires key)
+openai==0.27.8
+# optional for chat with PDF
+langchain==0.0.202
+pypdf==3.9.1
+# avoid textract, requires old six
+#textract==1.6.5
+
+# for HF embeddings
+sentence_transformers==2.2.2
+
+# local vector db
+chromadb==0.3.25
+# server vector db
+#pymilvus==2.2.8
+
+# weak url support, if can't install opencv etc. If comment-in this one, then comment-out unstructured[local-inference]==0.6.6
+# unstructured==0.6.6
+
+# strong support for images
+# Requires on Ubuntu: sudo apt-get install libmagic-dev poppler-utils tesseract-ocr libreoffice
+unstructured[local-inference]==0.7.4
+#pdf2image==1.16.3
+#pytesseract==0.3.10
+pillow
+
+pdfminer.six==20221105
+urllib3
+requests_file
+
+#pdf2image==1.16.3
+#pytesseract==0.3.10
+tabulate==0.9.0
+# FYI pandoc already part of requirements.txt
+
+# JSONLoader, but makes some trouble for some users
+# jq==1.4.1
+
+# to check licenses
+# Run: pip-licenses|grep -v 'BSD\|Apache\|MIT'
+pip-licenses==4.3.0
+
+# weaviate vector db
+weaviate-client==3.20.0
 # optional for chat with PDF
+langchain==0.0.202
+pypdf==3.9.1
-tiktoken==0.3.3
 # avoid textract, requires old six
 #textract==1.6.5

 # for HF embeddings
 sentence_transformers==2.2.2
-# for OpenAI embeddings (requires key)
-openai==0.27.6

 # local vector db
 chromadb==0.3.25
@@ -75,14 +124,14 @@

 # strong support for images
 # Requires on Ubuntu: sudo apt-get install libmagic-dev poppler-utils tesseract-ocr libreoffice
+unstructured[local-inference]==0.7.4
 #pdf2image==1.16.3
 #pytesseract==0.3.10
 pillow

 pdfminer.six==20221105
+urllib3
+requests_file

 #pdf2image==1.16.3
 #pytesseract==0.3.10
@@ -97,4 +146,8 @@
 pip-licenses==4.3.0

 # weaviate vector db
+weaviate-client==3.20.0
+faiss-gpu==1.7.2
+arxiv==1.4.7
+pymupdf==1.22.3  # AGPL license
+# extract-msg==0.41.1  # GPL3
stopping.py
CHANGED
@@ -1,17 +1,18 @@
 import torch
 from transformers import StoppingCriteria, StoppingCriteriaList

+from enums import PromptType


 class StoppingCriteriaSub(StoppingCriteria):

-    def __init__(self, stops=[], encounters=[], device="cuda"):
+    def __init__(self, stops=[], encounters=[], device="cuda", model_max_length=None):
         super().__init__()
         assert len(stops) % len(encounters) == 0, "Number of stops and encounters must match"
         self.encounters = encounters
         self.stops = [stop.to(device) for stop in stops]
         self.num_stops = [0] * len(stops)
+        self.model_max_length = model_max_length

     def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
         for stopi, stop in enumerate(self.stops):
@@ -20,12 +21,16 @@
                 if self.num_stops[stopi] >= self.encounters[stopi % len(self.encounters)]:
                     # print("Stopped", flush=True)
                     return True
+        if self.model_max_length is not None and input_ids[0].shape[0] >= self.model_max_length:
+            # critical limit
+            return True
         # print("Tokens: %s" % input_ids[0].cpu().numpy(), flush=True)
         # print("Stop Tokens: %s" % [x.cpu().numpy() for x in self.stops], flush=True)
         return False


-def get_stopping(prompt_type, tokenizer, device, human='<human>:', bot="<bot>:"):
+def get_stopping(prompt_type, prompt_dict, tokenizer, device, human='<human>:', bot="<bot>:", model_max_length=None):
+    # FIXME: prompt_dict unused currently
     if prompt_type in [PromptType.human_bot.name, PromptType.instruct_vicuna.name, PromptType.instruct_with_end.name]:
         if prompt_type == PromptType.human_bot.name:
             # encounters = [prompt.count(human) + 1, prompt.count(bot) + 1]
@@ -66,7 +71,8 @@
         stop_words_ids = [x[1:] if y[0] == '\n' else x for x, y in zip(stop_words_ids, stop_words)]
         # build stopper
         stopping_criteria = StoppingCriteriaList(
+            [StoppingCriteriaSub(stops=stop_words_ids, encounters=encounters, device=device,
+                                 model_max_length=model_max_length)])
     else:
         stopping_criteria = StoppingCriteriaList()
     return stopping_criteria
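get_stopping now takes prompt_dict (currently unused) and an optional model_max_length so generation also halts at the context limit. A sketch of wiring it into plain transformers generation; the checkpoint name is only an example and a CUDA-capable device is assumed:

    import torch
    from transformers import AutoModelForCausalLM, AutoTokenizer
    from stopping import get_stopping

    model_name = "h2oai/h2ogpt-oig-oasst1-512-6.9b"  # example checkpoint
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.float16, device_map="auto")

    # human_bot stop phrases plus a hard stop at 2048 total tokens
    stopping_criteria = get_stopping("human_bot", None, tokenizer, model.device,
                                     model_max_length=2048)
    inputs = tokenizer("<human>: Tell me a joke.\n<bot>:", return_tensors="pt").to(model.device)
    outputs = model.generate(**inputs, max_new_tokens=128, stopping_criteria=stopping_criteria)
    print(tokenizer.decode(outputs[0], skip_special_tokens=True))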
utils.py
CHANGED
@@ -14,6 +14,7 @@ import time
 import traceback
 import zipfile
 from datetime import datetime
+
 import filelock
 import requests, uuid
 from typing import Tuple, Callable, Dict
@@ -68,6 +69,25 @@ def ping():
     pass
 
 
+def ping_gpu():
+    try:
+        print('Ping_GPU: %s %s' % (str(datetime.now()), system_info()), flush=True)
+    except AttributeError:
+        # some programs wrap print and will fail with flush passed
+        pass
+    try:
+        ping_gpu_memory()
+    except Exception as e:
+        print('Ping_GPU memory failure: %s' % str(e), flush=True)
+
+
+def ping_gpu_memory():
+    from models.gpu_mem_track import MemTracker
+    gpu_tracker = MemTracker()  # define a GPU tracker
+    from torch.cuda import memory_summary
+    gpu_tracker.track()
+
+
 def get_torch_allocated():
     import torch
     return torch.cuda.memory_allocated()
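A small sketch (not in this diff) of how the new ping_gpu() could be polled in the background; the 60-second period and the thread wiring are illustrative assumptions:

# Hypothetical background heartbeat using the ping_gpu() added above.
import threading
import time

from utils import ping_gpu


def gpu_heartbeat(period_s=60):
    # logs GPU power/temperature/memory via system_info() once per period
    while True:
        ping_gpu()
        time.sleep(period_s)


threading.Thread(target=gpu_heartbeat, daemon=True).start()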
@@ -97,27 +117,29 @@ def system_info():
         system['CPU_C/%s' % k] = v
 
     # https://github.com/gpuopenanalytics/pynvml/blob/master/help_query_gpu.txt
+    try:
+        from pynvml.smi import nvidia_smi
+        nvsmi = nvidia_smi.getInstance()
+
+        gpu_power_dict = {'W_gpu%d' % i: x['power_readings']['power_draw'] for i, x in
+                          enumerate(nvsmi.DeviceQuery('power.draw')['gpu'])}
+        for k, v in gpu_power_dict.items():
+            system['GPU_W/%s' % k] = v
+
+        gpu_temp_dict = {'C_gpu%d' % i: x['temperature']['gpu_temp'] for i, x in
+                         enumerate(nvsmi.DeviceQuery('temperature.gpu')['gpu'])}
+        for k, v in gpu_temp_dict.items():
+            system['GPU_C/%s' % k] = v
+
+        gpu_memory_free_dict = {'MiB_gpu%d' % i: x['fb_memory_usage']['free'] for i, x in
+                                enumerate(nvsmi.DeviceQuery('memory.free')['gpu'])}
+        gpu_memory_total_dict = {'MiB_gpu%d' % i: x['fb_memory_usage']['total'] for i, x in
+                                 enumerate(nvsmi.DeviceQuery('memory.total')['gpu'])}
+        gpu_memory_frac_dict = {k: gpu_memory_free_dict[k] / gpu_memory_total_dict[k] for k in gpu_memory_total_dict}
+        for k, v in gpu_memory_frac_dict.items():
+            system[f'GPU_M/%s' % k] = v
+    except ModuleNotFoundError:
+        pass
     system['hash'] = get_githash()
 
     return system
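The same pynvml calls used above can be exercised standalone; a minimal sketch, assuming pynvml is installed and an NVIDIA driver is present:

# Mirrors the DeviceQuery calls in system_info() above; prints per-GPU power and temperature.
from pynvml.smi import nvidia_smi

nvsmi = nvidia_smi.getInstance()
power = nvsmi.DeviceQuery('power.draw')['gpu']
temp = nvsmi.DeviceQuery('temperature.gpu')['gpu']
for i, (p, t) in enumerate(zip(power, temp)):
    print('gpu%d: %s W, %s C' % (i, p['power_readings']['power_draw'], t['temperature']['gpu_temp']))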
@@ -166,35 +188,39 @@ def _zip_data(root_dirs=None, zip_file=None, base_dir='./'):
     return zip_file, zip_file
 
 
-def save_generate_output(output=None, base_model=None, save_dir=None):
+def save_generate_output(prompt=None, output=None, base_model=None, save_dir=None, where_from='unknown where from',
+                         extra_dict={}):
     try:
-        return _save_generate_output(output=output, base_model=base_model, save_dir=save_dir)
+        return _save_generate_output(prompt=prompt, output=output, base_model=base_model, save_dir=save_dir,
+                                     where_from=where_from, extra_dict=extra_dict)
     except Exception as e:
         traceback.print_exc()
         print('Exception in saving: %s' % str(e))
 
 
-def _save_generate_output(output=None, base_model=None, save_dir=None):
+def _save_generate_output(prompt=None, output=None, base_model=None, save_dir=None, where_from='unknown where from',
+                          extra_dict={}):
     """
     Save conversation to .json, row by row.
    json_file_path is path to final JSON file. If not in ., then will attempt to make directories.
     Appends if file exists
     """
+    prompt = '<not set>' if prompt is None else prompt
+    output = '<not set>' if output is None else output
     assert save_dir, "save_dir must be provided"
     if os.path.exists(save_dir) and not os.path.isdir(save_dir):
         raise RuntimeError("save_dir already exists and is not a directory!")
     os.makedirs(save_dir, exist_ok=True)
     import json
-    output = output[:-10]
+    dict_to_save = dict(prompt=prompt, text=output, time=time.ctime(), base_model=base_model, where_from=where_from)
+    dict_to_save.update(extra_dict)
     with filelock.FileLock("save_dir.lock"):
         # lock logging in case have concurrency
         with open(os.path.join(save_dir, "history.json"), "a") as f:
             # just add [ at start, and ] at end, and have proper JSON dataset
             f.write(
                 "  " + json.dumps(
+                    dict_to_save
                 ) + ",\n"
             )
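A short sketch of the extended save_generate_output() call and the row it appends to save_dir/history.json (the paths and field values below are illustrative only):

# Hypothetical call using the extended signature above; extra_dict lets callers attach
# arbitrary metadata (e.g. generation settings) to each saved row.
from utils import save_generate_output

save_generate_output(prompt="<human>: Who are you?\n<bot>:",
                     output="I am h2oGPT.",
                     base_model="h2oai/h2ogpt-oig-oasst1-512-6.9b",
                     save_dir="./save",
                     where_from="client_test",
                     extra_dict=dict(temperature=0.1, top_p=0.75))
# appends a line like:
#   {"prompt": "...", "text": "I am h2oGPT.", "time": "...", "base_model": "...",
#    "where_from": "client_test", "temperature": 0.1, "top_p": 0.75},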
@@ -800,6 +826,7 @@ def get_kwargs(func, exclude_names=None, **kwargs):
 
 
 import pkg_resources
+
 have_faiss = False
 
 try:
@@ -827,7 +854,7 @@ def hash_file(file):
         BUF_SIZE = 65536  # lets read stuff in 64kb chunks!
 
         md5 = hashlib.md5()
-        #sha1 = hashlib.sha1()
+        # sha1 = hashlib.sha1()
 
         with open(file, 'rb') as f:
             while True:
@@ -835,9 +862,67 @@ def hash_file(file):
                 if not data:
                     break
                 md5.update(data)
-                #sha1.update(data)
+                # sha1.update(data)
     except BaseException as e:
         print("Cannot hash %s due to %s" % (file, str(e)))
         traceback.print_exc()
         md5 = None
     return md5.hexdigest()
+
+
+def start_faulthandler():
+    # If hit server or any subprocess with signal SIGUSR1, it'll print out all threads stack trace, but won't quit or coredump
+    # If more than one fork tries to write at same time, then looks corrupted.
+    import faulthandler
+
+    # SIGUSR1 in h2oai/__init__.py as well
+    faulthandler.enable()
+    if hasattr(faulthandler, 'register'):
+        # windows/mac
+        import signal
+        faulthandler.register(signal.SIGUSR1)
+
+
+def get_hf_server(inference_server):
+    inf_split = inference_server.split(" ")
+    assert len(inf_split) == 1 or len(inf_split) == 3
+    inference_server = inf_split[0]
+    if len(inf_split) == 3:
+        headers = {"authorization": "%s %s" % (inf_split[1], inf_split[2])}
+    else:
+        headers = None
+    return inference_server, headers
+
+
+class FakeTokenizer:
+    """
+    1) For keeping track of model_max_length
+    2) For when model doesn't directly expose tokenizer but need to count tokens
+    """
+
+    def __init__(self, model_max_length=2048, encoding_name="cl100k_base"):
+        # don't push limit, since if using fake tokenizer, only estimate, and seen underestimates by order 250
+        self.model_max_length = model_max_length - 250
+        self.encoding_name = encoding_name
+        # The first time this runs, it will require an internet connection to download. Later runs won't need an internet connection.
+        import tiktoken
+        self.encoding = tiktoken.get_encoding(self.encoding_name)
+
+    def encode(self, x, *args, return_tensors="pt", **kwargs):
+        input_ids = self.encoding.encode(x, disallowed_special=())
+        if return_tensors == 'pt' and isinstance(input_ids, list):
+            import torch
+            input_ids = torch.tensor(input_ids)
+        return dict(input_ids=input_ids)
+
+    def decode(self, x, *args, **kwargs):
+        # input is input_ids[0] form
+        return self.encoding.decode(x)
+
+    def num_tokens_from_string(self, prompt: str) -> int:
+        """Returns the number of tokens in a text string."""
+        num_tokens = len(self.encoding.encode(prompt))
+        return num_tokens
+
+    def __call__(self, x, *args, **kwargs):
+        return self.encode(x, *args, **kwargs)
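Two quick sketches of the new helpers (all values below are made up): get_hf_server() splits an inference-server string of the form "URL auth-scheme token" into the URL and an authorization header, and FakeTokenizer counts tokens with tiktoken when no real tokenizer is available.

from utils import get_hf_server, FakeTokenizer

# 'Bearer hf_xxx' is a placeholder credential, not a real token
server, headers = get_hf_server("http://192.168.1.1:6112 Bearer hf_xxx")
# server  -> 'http://192.168.1.1:6112'
# headers -> {'authorization': 'Bearer hf_xxx'}

tok = FakeTokenizer(model_max_length=2048)   # effective limit becomes 2048 - 250
n = tok.num_tokens_from_string("Why is the sky blue?")
ids = tok("Why is the sky blue?")['input_ids']   # __call__ -> encode(); returns a torch tensor
text = tok.decode(ids.tolist() if hasattr(ids, 'tolist') else ids)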
utils_langchain.py
ADDED
@@ -0,0 +1,64 @@
+from typing import Any, Dict, List, Union, Optional
+import time
+import queue
+
+from langchain.callbacks.base import BaseCallbackHandler
+from langchain.schema import LLMResult
+
+
+class StreamingGradioCallbackHandler(BaseCallbackHandler):
+    """
+    Similar to H2OTextIteratorStreamer that is for HF backend, but here LangChain backend
+    """
+    def __init__(self, timeout: Optional[float] = None, block=True):
+        super().__init__()
+        self.text_queue = queue.SimpleQueue()
+        self.stop_signal = None
+        self.do_stop = False
+        self.timeout = timeout
+        self.block = block
+
+    def on_llm_start(
+            self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
+    ) -> None:
+        """Run when LLM starts running. Clean the queue."""
+        while not self.text_queue.empty():
+            try:
+                self.text_queue.get(block=False)
+            except queue.Empty:
+                continue
+
+    def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
+        """Run on new LLM token. Only available when streaming is enabled."""
+        self.text_queue.put(token)
+
+    def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
+        """Run when LLM ends running."""
+        self.text_queue.put(self.stop_signal)
+
+    def on_llm_error(
+            self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
+    ) -> None:
+        """Run when LLM errors."""
+        self.text_queue.put(self.stop_signal)
+
+    def __iter__(self):
+        return self
+
+    def __next__(self):
+        while True:
+            try:
+                value = self.stop_signal  # value looks unused in pycharm, not true
+                if self.do_stop:
+                    print("hit stop", flush=True)
+                    # could raise or break, maybe best to raise and make parent see if any exception in thread
+                    raise StopIteration()
+                    # break
+                value = self.text_queue.get(block=self.block, timeout=self.timeout)
+                break
+            except queue.Empty:
+                time.sleep(0.01)
+        if value == self.stop_signal:
+            raise StopIteration()
+        else:
+            return value
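A usage sketch for the new handler (not in this diff): any LangChain LLM that emits on_llm_new_token callbacks works; the OpenAI wrapper with streaming=True is only an assumed example and needs a valid OPENAI_API_KEY.

# Hypothetical wiring: run the LLM call in a thread and iterate the handler for tokens,
# the kind of pattern used to stream partial output to a Gradio UI.
import threading

from langchain.llms import OpenAI
from utils_langchain import StreamingGradioCallbackHandler

handler = StreamingGradioCallbackHandler(timeout=None, block=True)
llm = OpenAI(streaming=True, callbacks=[handler], temperature=0)

thread = threading.Thread(target=llm, args=("Explain token streaming in one sentence.",))
thread.start()

text = ""
for token in handler:        # __next__ blocks until a token or the stop signal arrives
    text += token
    print(text, flush=True)  # a Gradio generator would yield this partial text instead
thread.join()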