pseudotensor commited on
Commit
eeb7ca1
1 Parent(s): 3a1a60a

Update with h2oGPT hash 13a8343d2a96885985bda8c4480bbb23cf55bb9b

Browse files
LICENSE DELETED
@@ -1 +0,0 @@
1
- ../../LICENSE
 
 
LICENSE ADDED
@@ -0,0 +1,201 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ Apache License
2
+ Version 2.0, January 2004
3
+ http://www.apache.org/licenses/
4
+
5
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6
+
7
+ 1. Definitions.
8
+
9
+ "License" shall mean the terms and conditions for use, reproduction,
10
+ and distribution as defined by Sections 1 through 9 of this document.
11
+
12
+ "Licensor" shall mean the copyright owner or entity authorized by
13
+ the copyright owner that is granting the License.
14
+
15
+ "Legal Entity" shall mean the union of the acting entity and all
16
+ other entities that control, are controlled by, or are under common
17
+ control with that entity. For the purposes of this definition,
18
+ "control" means (i) the power, direct or indirect, to cause the
19
+ direction or management of such entity, whether by contract or
20
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
21
+ outstanding shares, or (iii) beneficial ownership of such entity.
22
+
23
+ "You" (or "Your") shall mean an individual or Legal Entity
24
+ exercising permissions granted by this License.
25
+
26
+ "Source" form shall mean the preferred form for making modifications,
27
+ including but not limited to software source code, documentation
28
+ source, and configuration files.
29
+
30
+ "Object" form shall mean any form resulting from mechanical
31
+ transformation or translation of a Source form, including but
32
+ not limited to compiled object code, generated documentation,
33
+ and conversions to other media types.
34
+
35
+ "Work" shall mean the work of authorship, whether in Source or
36
+ Object form, made available under the License, as indicated by a
37
+ copyright notice that is included in or attached to the work
38
+ (an example is provided in the Appendix below).
39
+
40
+ "Derivative Works" shall mean any work, whether in Source or Object
41
+ form, that is based on (or derived from) the Work and for which the
42
+ editorial revisions, annotations, elaborations, or other modifications
43
+ represent, as a whole, an original work of authorship. For the purposes
44
+ of this License, Derivative Works shall not include works that remain
45
+ separable from, or merely link (or bind by name) to the interfaces of,
46
+ the Work and Derivative Works thereof.
47
+
48
+ "Contribution" shall mean any work of authorship, including
49
+ the original version of the Work and any modifications or additions
50
+ to that Work or Derivative Works thereof, that is intentionally
51
+ submitted to Licensor for inclusion in the Work by the copyright owner
52
+ or by an individual or Legal Entity authorized to submit on behalf of
53
+ the copyright owner. For the purposes of this definition, "submitted"
54
+ means any form of electronic, verbal, or written communication sent
55
+ to the Licensor or its representatives, including but not limited to
56
+ communication on electronic mailing lists, source code control systems,
57
+ and issue tracking systems that are managed by, or on behalf of, the
58
+ Licensor for the purpose of discussing and improving the Work, but
59
+ excluding communication that is conspicuously marked or otherwise
60
+ designated in writing by the copyright owner as "Not a Contribution."
61
+
62
+ "Contributor" shall mean Licensor and any individual or Legal Entity
63
+ on behalf of whom a Contribution has been received by Licensor and
64
+ subsequently incorporated within the Work.
65
+
66
+ 2. Grant of Copyright License. Subject to the terms and conditions of
67
+ this License, each Contributor hereby grants to You a perpetual,
68
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69
+ copyright license to reproduce, prepare Derivative Works of,
70
+ publicly display, publicly perform, sublicense, and distribute the
71
+ Work and such Derivative Works in Source or Object form.
72
+
73
+ 3. Grant of Patent License. Subject to the terms and conditions of
74
+ this License, each Contributor hereby grants to You a perpetual,
75
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76
+ (except as stated in this section) patent license to make, have made,
77
+ use, offer to sell, sell, import, and otherwise transfer the Work,
78
+ where such license applies only to those patent claims licensable
79
+ by such Contributor that are necessarily infringed by their
80
+ Contribution(s) alone or by combination of their Contribution(s)
81
+ with the Work to which such Contribution(s) was submitted. If You
82
+ institute patent litigation against any entity (including a
83
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
84
+ or a Contribution incorporated within the Work constitutes direct
85
+ or contributory patent infringement, then any patent licenses
86
+ granted to You under this License for that Work shall terminate
87
+ as of the date such litigation is filed.
88
+
89
+ 4. Redistribution. You may reproduce and distribute copies of the
90
+ Work or Derivative Works thereof in any medium, with or without
91
+ modifications, and in Source or Object form, provided that You
92
+ meet the following conditions:
93
+
94
+ (a) You must give any other recipients of the Work or
95
+ Derivative Works a copy of this License; and
96
+
97
+ (b) You must cause any modified files to carry prominent notices
98
+ stating that You changed the files; and
99
+
100
+ (c) You must retain, in the Source form of any Derivative Works
101
+ that You distribute, all copyright, patent, trademark, and
102
+ attribution notices from the Source form of the Work,
103
+ excluding those notices that do not pertain to any part of
104
+ the Derivative Works; and
105
+
106
+ (d) If the Work includes a "NOTICE" text file as part of its
107
+ distribution, then any Derivative Works that You distribute must
108
+ include a readable copy of the attribution notices contained
109
+ within such NOTICE file, excluding those notices that do not
110
+ pertain to any part of the Derivative Works, in at least one
111
+ of the following places: within a NOTICE text file distributed
112
+ as part of the Derivative Works; within the Source form or
113
+ documentation, if provided along with the Derivative Works; or,
114
+ within a display generated by the Derivative Works, if and
115
+ wherever such third-party notices normally appear. The contents
116
+ of the NOTICE file are for informational purposes only and
117
+ do not modify the License. You may add Your own attribution
118
+ notices within Derivative Works that You distribute, alongside
119
+ or as an addendum to the NOTICE text from the Work, provided
120
+ that such additional attribution notices cannot be construed
121
+ as modifying the License.
122
+
123
+ You may add Your own copyright statement to Your modifications and
124
+ may provide additional or different license terms and conditions
125
+ for use, reproduction, or distribution of Your modifications, or
126
+ for any such Derivative Works as a whole, provided Your use,
127
+ reproduction, and distribution of the Work otherwise complies with
128
+ the conditions stated in this License.
129
+
130
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
131
+ any Contribution intentionally submitted for inclusion in the Work
132
+ by You to the Licensor shall be under the terms and conditions of
133
+ this License, without any additional terms or conditions.
134
+ Notwithstanding the above, nothing herein shall supersede or modify
135
+ the terms of any separate license agreement you may have executed
136
+ with Licensor regarding such Contributions.
137
+
138
+ 6. Trademarks. This License does not grant permission to use the trade
139
+ names, trademarks, service marks, or product names of the Licensor,
140
+ except as required for reasonable and customary use in describing the
141
+ origin of the Work and reproducing the content of the NOTICE file.
142
+
143
+ 7. Disclaimer of Warranty. Unless required by applicable law or
144
+ agreed to in writing, Licensor provides the Work (and each
145
+ Contributor provides its Contributions) on an "AS IS" BASIS,
146
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147
+ implied, including, without limitation, any warranties or conditions
148
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149
+ PARTICULAR PURPOSE. You are solely responsible for determining the
150
+ appropriateness of using or redistributing the Work and assume any
151
+ risks associated with Your exercise of permissions under this License.
152
+
153
+ 8. Limitation of Liability. In no event and under no legal theory,
154
+ whether in tort (including negligence), contract, or otherwise,
155
+ unless required by applicable law (such as deliberate and grossly
156
+ negligent acts) or agreed to in writing, shall any Contributor be
157
+ liable to You for damages, including any direct, indirect, special,
158
+ incidental, or consequential damages of any character arising as a
159
+ result of this License or out of the use or inability to use the
160
+ Work (including but not limited to damages for loss of goodwill,
161
+ work stoppage, computer failure or malfunction, or any and all
162
+ other commercial damages or losses), even if such Contributor
163
+ has been advised of the possibility of such damages.
164
+
165
+ 9. Accepting Warranty or Additional Liability. While redistributing
166
+ the Work or Derivative Works thereof, You may choose to offer,
167
+ and charge a fee for, acceptance of support, warranty, indemnity,
168
+ or other liability obligations and/or rights consistent with this
169
+ License. However, in accepting such obligations, You may act only
170
+ on Your own behalf and on Your sole responsibility, not on behalf
171
+ of any other Contributor, and only if You agree to indemnify,
172
+ defend, and hold each Contributor harmless for any liability
173
+ incurred by, or claims asserted against, such Contributor by reason
174
+ of your accepting any such warranty or additional liability.
175
+
176
+ END OF TERMS AND CONDITIONS
177
+
178
+ APPENDIX: How to apply the Apache License to your work.
179
+
180
+ To apply the Apache License to your work, attach the following
181
+ boilerplate notice, with the fields enclosed by brackets "[]"
182
+ replaced with your own identifying information. (Don't include
183
+ the brackets!) The text should be enclosed in the appropriate
184
+ comment syntax for the file format. We also recommend that a
185
+ file or class name and description of purpose be included on the
186
+ same "printed page" as the copyright notice for easier
187
+ identification within third-party archives.
188
+
189
+ Copyright 2023 Rohan Taori, Ishaan Gulrajani, Tianyi Zhang, Yann Dubois, Xuechen Li
190
+
191
+ Licensed under the Apache License, Version 2.0 (the "License");
192
+ you may not use this file except in compliance with the License.
193
+ You may obtain a copy of the License at
194
+
195
+ http://www.apache.org/licenses/LICENSE-2.0
196
+
197
+ Unless required by applicable law or agreed to in writing, software
198
+ distributed under the License is distributed on an "AS IS" BASIS,
199
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200
+ See the License for the specific language governing permissions and
201
+ limitations under the License.
client_test.py DELETED
@@ -1 +0,0 @@
1
- ../../client_test.py
 
 
client_test.py ADDED
@@ -0,0 +1,278 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Client test.
3
+
4
+ Run server:
5
+
6
+ python generate.py --base_model=h2oai/h2ogpt-oig-oasst1-512-6_9b
7
+
8
+ NOTE: For private models, add --use-auth_token=True
9
+
10
+ NOTE: --infer_devices=True (default) must be used for multi-GPU in case see failures with cuda:x cuda:y mismatches.
11
+ Currently, this will force model to be on a single GPU.
12
+
13
+ Then run this client as:
14
+
15
+ python client_test.py
16
+
17
+
18
+
19
+ For HF spaces:
20
+
21
+ HOST="https://h2oai-h2ogpt-chatbot.hf.space" python client_test.py
22
+
23
+ Result:
24
+
25
+ Loaded as API: https://h2oai-h2ogpt-chatbot.hf.space ✔
26
+ {'instruction_nochat': 'Who are you?', 'iinput_nochat': '', 'response': 'I am h2oGPT, a large language model developed by LAION.', 'sources': ''}
27
+
28
+
29
+ For demo:
30
+
31
+ HOST="https://gpt.h2o.ai" python client_test.py
32
+
33
+ Result:
34
+
35
+ Loaded as API: https://gpt.h2o.ai ✔
36
+ {'instruction_nochat': 'Who are you?', 'iinput_nochat': '', 'response': 'I am h2oGPT, a chatbot created by LAION.', 'sources': ''}
37
+
38
+ NOTE: Raw output from API for nochat case is a string of a python dict and will remain so if other entries are added to dict:
39
+
40
+ {'response': "I'm h2oGPT, a large language model by H2O.ai, the visionary leader in democratizing AI.", 'sources': ''}
41
+
42
+
43
+ """
44
+ import ast
45
+ import time
46
+ import os
47
+ import markdown # pip install markdown
48
+ import pytest
49
+ from bs4 import BeautifulSoup # pip install beautifulsoup4
50
+
51
+ from enums import DocumentChoices
52
+
53
+ debug = False
54
+
55
+ os.environ['HF_HUB_DISABLE_TELEMETRY'] = '1'
56
+
57
+
58
+ def get_client(serialize=True):
59
+ from gradio_client import Client
60
+
61
+ client = Client(os.getenv('HOST', "http://localhost:7860"), serialize=serialize)
62
+ if debug:
63
+ print(client.view_api(all_endpoints=True))
64
+ return client
65
+
66
+
67
+ def get_args(prompt, prompt_type, chat=False, stream_output=False,
68
+ max_new_tokens=50,
69
+ top_k_docs=3,
70
+ langchain_mode='Disabled'):
71
+ from collections import OrderedDict
72
+ kwargs = OrderedDict(instruction=prompt if chat else '', # only for chat=True
73
+ iinput='', # only for chat=True
74
+ context='',
75
+ # streaming output is supported, loops over and outputs each generation in streaming mode
76
+ # but leave stream_output=False for simple input/output mode
77
+ stream_output=stream_output,
78
+ prompt_type=prompt_type,
79
+ prompt_dict='',
80
+ temperature=0.1,
81
+ top_p=0.75,
82
+ top_k=40,
83
+ num_beams=1,
84
+ max_new_tokens=max_new_tokens,
85
+ min_new_tokens=0,
86
+ early_stopping=False,
87
+ max_time=20,
88
+ repetition_penalty=1.0,
89
+ num_return_sequences=1,
90
+ do_sample=True,
91
+ chat=chat,
92
+ instruction_nochat=prompt if not chat else '',
93
+ iinput_nochat='', # only for chat=False
94
+ langchain_mode=langchain_mode,
95
+ top_k_docs=top_k_docs,
96
+ chunk=True,
97
+ chunk_size=512,
98
+ document_choice=[DocumentChoices.All_Relevant.name],
99
+ )
100
+ if chat:
101
+ # add chatbot output on end. Assumes serialize=False
102
+ kwargs.update(dict(chatbot=[]))
103
+
104
+ return kwargs, list(kwargs.values())
105
+
106
+
107
+ @pytest.mark.skip(reason="For manual use against some server, no server launched")
108
+ def test_client_basic():
109
+ return run_client_nochat(prompt='Who are you?', prompt_type='human_bot', max_new_tokens=50)
110
+
111
+
112
+ def run_client_nochat(prompt, prompt_type, max_new_tokens):
113
+ kwargs, args = get_args(prompt, prompt_type, chat=False, max_new_tokens=max_new_tokens)
114
+
115
+ api_name = '/submit_nochat'
116
+ client = get_client(serialize=True)
117
+ res = client.predict(
118
+ *tuple(args),
119
+ api_name=api_name,
120
+ )
121
+ print("Raw client result: %s" % res, flush=True)
122
+ res_dict = dict(prompt=kwargs['instruction_nochat'], iinput=kwargs['iinput_nochat'],
123
+ response=md_to_text(res))
124
+ print(res_dict)
125
+ return res_dict
126
+
127
+
128
+ @pytest.mark.skip(reason="For manual use against some server, no server launched")
129
+ def test_client_basic_api():
130
+ return run_client_nochat_api(prompt='Who are you?', prompt_type='human_bot', max_new_tokens=50)
131
+
132
+
133
+ def run_client_nochat_api(prompt, prompt_type, max_new_tokens):
134
+ kwargs, args = get_args(prompt, prompt_type, chat=False, max_new_tokens=max_new_tokens)
135
+
136
+ api_name = '/submit_nochat_api' # NOTE: like submit_nochat but stable API for string dict passing
137
+ client = get_client(serialize=True)
138
+ res = client.predict(
139
+ str(dict(kwargs)),
140
+ api_name=api_name,
141
+ )
142
+ print("Raw client result: %s" % res, flush=True)
143
+ res_dict = dict(prompt=kwargs['instruction_nochat'], iinput=kwargs['iinput_nochat'],
144
+ response=md_to_text(ast.literal_eval(res)['response']),
145
+ sources=ast.literal_eval(res)['sources'])
146
+ print(res_dict)
147
+ return res_dict
148
+
149
+
150
+ @pytest.mark.skip(reason="For manual use against some server, no server launched")
151
+ def test_client_basic_api_lean():
152
+ return run_client_nochat_api_lean(prompt='Who are you?', prompt_type='human_bot', max_new_tokens=50)
153
+
154
+
155
+ def run_client_nochat_api_lean(prompt, prompt_type, max_new_tokens):
156
+ kwargs = dict(instruction_nochat=prompt)
157
+
158
+ api_name = '/submit_nochat_api' # NOTE: like submit_nochat but stable API for string dict passing
159
+ client = get_client(serialize=True)
160
+ res = client.predict(
161
+ str(dict(kwargs)),
162
+ api_name=api_name,
163
+ )
164
+ print("Raw client result: %s" % res, flush=True)
165
+ res_dict = dict(prompt=kwargs['instruction_nochat'],
166
+ response=md_to_text(ast.literal_eval(res)['response']),
167
+ sources=ast.literal_eval(res)['sources'])
168
+ print(res_dict)
169
+ return res_dict
170
+
171
+
172
+ @pytest.mark.skip(reason="For manual use against some server, no server launched")
173
+ def test_client_basic_api_lean_morestuff():
174
+ return run_client_nochat_api_lean_morestuff(prompt='Who are you?', prompt_type='human_bot', max_new_tokens=50)
175
+
176
+
177
+ def run_client_nochat_api_lean_morestuff(prompt, prompt_type, max_new_tokens):
178
+ kwargs = dict(
179
+ instruction='',
180
+ iinput='',
181
+ context='',
182
+ stream_output=False,
183
+ prompt_type='human_bot',
184
+ temperature=0.1,
185
+ top_p=0.75,
186
+ top_k=40,
187
+ num_beams=1,
188
+ max_new_tokens=256,
189
+ min_new_tokens=0,
190
+ early_stopping=False,
191
+ max_time=20,
192
+ repetition_penalty=1.0,
193
+ num_return_sequences=1,
194
+ do_sample=True,
195
+ chat=False,
196
+ instruction_nochat=prompt,
197
+ iinput_nochat='',
198
+ langchain_mode='Disabled',
199
+ top_k_docs=4,
200
+ document_choice=['All'],
201
+ )
202
+
203
+ api_name = '/submit_nochat_api' # NOTE: like submit_nochat but stable API for string dict passing
204
+ client = get_client(serialize=True)
205
+ res = client.predict(
206
+ str(dict(kwargs)),
207
+ api_name=api_name,
208
+ )
209
+ print("Raw client result: %s" % res, flush=True)
210
+ res_dict = dict(prompt=kwargs['instruction_nochat'],
211
+ response=md_to_text(ast.literal_eval(res)['response']),
212
+ sources=ast.literal_eval(res)['sources'])
213
+ print(res_dict)
214
+ return res_dict
215
+
216
+
217
+ @pytest.mark.skip(reason="For manual use against some server, no server launched")
218
+ def test_client_chat():
219
+ return run_client_chat(prompt='Who are you?', prompt_type='human_bot', stream_output=False, max_new_tokens=50,
220
+ langchain_mode='Disabled')
221
+
222
+
223
+ def run_client_chat(prompt, prompt_type, stream_output, max_new_tokens, langchain_mode):
224
+ client = get_client(serialize=False)
225
+
226
+ kwargs, args = get_args(prompt, prompt_type, chat=True, stream_output=stream_output,
227
+ max_new_tokens=max_new_tokens, langchain_mode=langchain_mode)
228
+ return run_client(client, prompt, args, kwargs)
229
+
230
+
231
+ def run_client(client, prompt, args, kwargs, do_md_to_text=True, verbose=False):
232
+ res = client.predict(*tuple(args), api_name='/instruction')
233
+ args[-1] += [res[-1]]
234
+
235
+ res_dict = kwargs
236
+ res_dict['prompt'] = prompt
237
+ if not kwargs['stream_output']:
238
+ res = client.predict(*tuple(args), api_name='/instruction_bot')
239
+ res_dict['response'] = res[0][-1][1]
240
+ print(md_to_text(res_dict['response'], do_md_to_text=do_md_to_text))
241
+ return res_dict, client
242
+ else:
243
+ job = client.submit(*tuple(args), api_name='/instruction_bot')
244
+ res1 = ''
245
+ while not job.done():
246
+ outputs_list = job.communicator.job.outputs
247
+ if outputs_list:
248
+ res = job.communicator.job.outputs[-1]
249
+ res1 = res[0][-1][-1]
250
+ res1 = md_to_text(res1, do_md_to_text=do_md_to_text)
251
+ print(res1)
252
+ time.sleep(0.1)
253
+ full_outputs = job.outputs()
254
+ if verbose:
255
+ print('job.outputs: %s' % str(full_outputs))
256
+ # ensure get ending to avoid race
257
+ # -1 means last response if streaming
258
+ # 0 means get text_output, ignore exception_text
259
+ # 0 means get list within text_output that looks like [[prompt], [answer]]
260
+ # 1 means get bot answer, so will have last bot answer
261
+ res_dict['response'] = md_to_text(full_outputs[-1][0][0][1], do_md_to_text=do_md_to_text)
262
+ return res_dict, client
263
+
264
+
265
+ def md_to_text(md, do_md_to_text=True):
266
+ if not do_md_to_text:
267
+ return md
268
+ assert md is not None, "Markdown is None"
269
+ html = markdown.markdown(md)
270
+ soup = BeautifulSoup(html, features='html.parser')
271
+ return soup.get_text()
272
+
273
+
274
+ if __name__ == '__main__':
275
+ test_client_basic()
276
+ test_client_basic_api()
277
+ test_client_basic_api_lean()
278
+ test_client_basic_api_lean_morestuff()
create_data.py DELETED
@@ -1 +0,0 @@
1
- ../../create_data.py
 
 
create_data.py ADDED
@@ -0,0 +1,1809 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Dataset creation tools.
3
+
4
+ Keep to-level imports clean of non-trivial imports for specific tools,
5
+ because this file is imported for various purposes
6
+ """
7
+
8
+ import ast
9
+ import concurrent.futures
10
+ import contextlib
11
+ import hashlib
12
+ import json
13
+ import os
14
+ import shutil
15
+ import signal
16
+ import sys
17
+ import traceback
18
+ from concurrent.futures import ProcessPoolExecutor
19
+
20
+ import psutil
21
+ import pytest
22
+ import pandas as pd
23
+ import numpy as np
24
+ from tqdm import tqdm
25
+
26
+ from utils import flatten_list, remove
27
+
28
+
29
+ def parse_rst_file(filepath):
30
+ with open(filepath, 'r') as f:
31
+ input_data = f.read()
32
+ settings_overrides = {'initial_header_level': 2}
33
+ from docutils import core
34
+ document = core.publish_doctree(
35
+ source=input_data,
36
+ source_path=filepath,
37
+ settings_overrides=settings_overrides,
38
+ )
39
+ qa_pairs = []
40
+ current_section = None
41
+ current_question = ""
42
+ current_answer = ""
43
+ for node in document.traverse():
44
+ if node.__class__.__name__ == 'section':
45
+ current_section = ""
46
+ elif current_section is not None:
47
+ if node.__class__.__name__ == 'Text':
48
+ if node.astext()[-1] == "?":
49
+ if current_question:
50
+ qa_pairs.append((current_question, current_answer))
51
+ current_question = node.astext()
52
+ current_answer = ""
53
+ else:
54
+ current_answer += node.astext()
55
+ if current_answer:
56
+ qa_pairs.append((current_question, current_answer))
57
+ return {k: v for k, v in qa_pairs}
58
+
59
+
60
+ def test_scrape_dai_docs():
61
+ home = os.path.expanduser('~')
62
+ file = os.path.join(home, 'h2oai/docs/faq.rst')
63
+ qa_pairs = parse_rst_file(file)
64
+ prompt_type = 'human_bot'
65
+ from prompter import prompt_types
66
+ assert prompt_type in prompt_types
67
+ save_thing = [{"instruction": k, "output": v, 'prompt_type': prompt_type} for k, v in qa_pairs.items()]
68
+ output_file = "dai_faq.json"
69
+ with open(output_file, "wt") as f:
70
+ f.write(json.dumps(save_thing, indent=2))
71
+
72
+
73
+ def test_scrape_dai_docs_all():
74
+ """
75
+ pytest create_data.py::test_scrape_dai_docs_all
76
+ """
77
+ import glob
78
+ import nltk
79
+ nltk.download('punkt')
80
+ dd = {}
81
+ np.random.seed(1234)
82
+ home = os.path.expanduser('~')
83
+ files = list(glob.glob(os.path.join(home, "h2oai/docs/**/*rst")))
84
+ np.random.shuffle(files)
85
+ val_count = int(0.05 * len(files))
86
+ train_files = files[val_count:]
87
+ valid_files = files[:val_count]
88
+ things = [
89
+ ("dai_docs.train.json", train_files),
90
+ ("dai_docs.valid.json", valid_files)
91
+ ]
92
+ for LEN in [100, 200, 500]:
93
+ for output_file, ff in things:
94
+ if output_file not in dd:
95
+ dd[output_file] = []
96
+ for f in ff:
97
+ with open(f) as input:
98
+ blob = input.read()
99
+ blob = blob.replace("~~", "")
100
+ blob = blob.replace("==", "")
101
+ blob = blob.replace("''", "")
102
+ blob = blob.replace("--", "")
103
+ blob = blob.replace("**", "")
104
+ dd[output_file].extend(get_sentences(blob, length=LEN))
105
+ for output_file, _ in things:
106
+ save_thing = [{"output": k.strip(), 'prompt_type': 'plain'} for k in dd[output_file]]
107
+ with open(output_file, "wt") as f:
108
+ f.write(json.dumps(save_thing, indent=2))
109
+
110
+
111
+ def get_sentences(blob, length):
112
+ """
113
+ break-up input text into sentences and then output list of sentences of about length in size
114
+ :param blob:
115
+ :param length:
116
+ :return:
117
+ """
118
+ import nltk
119
+ nltk.download('punkt')
120
+ from nltk.tokenize import sent_tokenize
121
+ sentences = sent_tokenize(blob)
122
+ my_sentences = []
123
+ my_string = ""
124
+ for sentence in sentences:
125
+ if len(my_string) + len(sentence) <= length:
126
+ if my_string:
127
+ my_string += " " + sentence
128
+ else:
129
+ my_string = sentence
130
+ else:
131
+ my_sentences.append(my_string)
132
+ my_string = ""
133
+ return my_sentences or [my_string]
134
+
135
+
136
+ def setup_dai_docs(path=None, dst="working_dir_docs", from_hf=False):
137
+ """
138
+ Only supported if have access to source code or HF token for HF spaces and from_hf=True
139
+ :param path:
140
+ :param dst:
141
+ :param from_hf:
142
+ :return:
143
+ """
144
+
145
+ home = os.path.expanduser('~')
146
+
147
+ if from_hf:
148
+ # assumes
149
+ from huggingface_hub import hf_hub_download
150
+ # True for case when locally already logged in with correct token, so don't have to set key
151
+ token = os.getenv('HUGGINGFACE_API_TOKEN', True)
152
+ path_to_zip_file = hf_hub_download('h2oai/dai_docs', 'dai_docs.zip', token=token, repo_type='dataset')
153
+ path = 'h2oai'
154
+ import zipfile
155
+ with zipfile.ZipFile(path_to_zip_file, 'r') as zip_ref:
156
+ zip_ref.extractall(path)
157
+ path = os.path.join(path, 'docs/**/*')
158
+
159
+ if path is None:
160
+ if os.path.isdir(os.path.join(home, 'h2oai')):
161
+ path = os.path.join(home, "h2oai/docs/**/*")
162
+ else:
163
+ assert os.path.isdir(os.path.join(home, 'h2oai.superclean')), '%s does not exist' % path
164
+ path = os.path.join(home, "h2oai.superclean/docs/**/*")
165
+ import glob
166
+ files = list(glob.glob(path, recursive=True))
167
+
168
+ # pandoc can't find include files
169
+
170
+ remove(dst)
171
+ os.makedirs(dst)
172
+
173
+ # copy full tree, for absolute paths in rst
174
+ for fil in files:
175
+ if os.path.isfile(fil):
176
+ shutil.copy(fil, dst)
177
+
178
+ # hack for relative path
179
+ scorers_dir = os.path.join(dst, 'scorers')
180
+ makedirs(scorers_dir)
181
+ for fil in glob.glob(os.path.join(dst, '*.frag')):
182
+ shutil.copy(fil, scorers_dir)
183
+
184
+ return dst
185
+
186
+
187
+ def rst_to_outputs(files, min_len=30, max_len=2048 // 2 - 30):
188
+ # account for sequence length (context window) including prompt and input and output
189
+
190
+ # os.system('pandoc -f rst -t plain ./expert_settings/nlp_settings.rst')
191
+ import pypandoc
192
+ basedir = os.path.abspath(os.getcwd())
193
+
194
+ outputs = []
195
+ for fil in files:
196
+ os.chdir(basedir)
197
+ os.chdir(os.path.dirname(fil))
198
+ fil = os.path.basename(fil)
199
+ print("Processing %s" % fil, flush=True)
200
+ # out_format can be one of: asciidoc, asciidoctor, beamer, biblatex, bibtex, commonmark, commonmark_x,
201
+ # context, csljson, docbook, docbook4, docbook5, docx, dokuwiki,
202
+ # dzslides, epub, epub2, epub3, fb2, gfm, haddock, html, html4, html5, icml,
203
+ # ipynb, jats, jats_archiving, jats_articleauthoring, jats_publishing, jira,
204
+ # json, latex, man,
205
+ # markdown, markdown_github, markdown_mmd, markdown_phpextra, markdown_strict,
206
+ # mediawiki, ms, muse, native, odt, opendocument, opml, org, pdf, plain, pptx,
207
+ # revealjs, rst, rtf, s5, slideous, slidy, tei, texinfo, textile, xwiki, zimwiki
208
+ out_format = 'plain'
209
+ # avoid extra new lines injected into text
210
+ extra_args = ['--wrap=preserve', '--resource path="%s" % dst']
211
+
212
+ plain_list = []
213
+ try:
214
+ # valid for expert settings
215
+ input_rst = pypandoc.convert_file(fil, 'rst')
216
+ input_list = input_rst.split('\n``')
217
+ for input_subrst in input_list:
218
+ input_plain = pypandoc.convert_text(input_subrst, format='rst', to='plain')
219
+ plain_list.append([input_plain, fil])
220
+ except Exception as e:
221
+ print("file exception: %s %s" % (fil, str(e)), flush=True)
222
+
223
+ if not plain_list:
224
+ # if failed to process as pieces of rst, then
225
+ output = pypandoc.convert_file(fil, out_format, extra_args=extra_args, format='rst')
226
+ outputs1 = get_sentences(output, length=max_len)
227
+ for oi, output in enumerate(outputs1):
228
+ output = output.replace('\n\n', '\n')
229
+ plain_list.append([output, fil])
230
+ outputs.extend(plain_list)
231
+
232
+ # report:
233
+ # [print(len(x)) for x in outputs]
234
+
235
+ # deal with blocks longer than context size (sequence length) of 2048
236
+ new_outputs = []
237
+ num_truncated = 0
238
+ num_orig = len(outputs)
239
+ for output, fil in outputs:
240
+ if len(output) < max_len:
241
+ new_outputs.append([output, fil])
242
+ continue
243
+ outputs1 = get_sentences(output, length=max_len)
244
+ for oi, output1 in enumerate(outputs1):
245
+ output1 = output1.replace('\n\n', '\n')
246
+ new_outputs.append([output1, fil])
247
+ num_truncated += 1
248
+ print('num_orig: %s num_truncated: %s' % (num_orig, num_truncated), flush=True)
249
+
250
+ new_outputs = [[k.strip(), fil] for k, fil in new_outputs if len(k.strip()) > min_len]
251
+
252
+ return new_outputs
253
+
254
+
255
+ def test_scrape_dai_docs_all_pandoc():
256
+ """
257
+ pytest -s -v create_data.py::test_scrape_dai_docs_all_pandoc
258
+ :return:
259
+ """
260
+
261
+ dst = setup_dai_docs()
262
+
263
+ import glob
264
+ files = list(glob.glob(os.path.join(dst, '*rst'), recursive=True))
265
+
266
+ basedir = os.path.abspath(os.getcwd())
267
+ new_outputs = rst_to_outputs(files)
268
+ os.chdir(basedir)
269
+
270
+ remove(dst)
271
+ save_thing = [{"output": k.strip(), 'prompt_type': 'plain'} for k in new_outputs]
272
+ output_file = "dai_docs.train_cleaned.json"
273
+ with open(output_file, "wt") as f:
274
+ f.write(json.dumps(save_thing, indent=2))
275
+
276
+
277
+ def test_config_to_json():
278
+ """
279
+ Needs to run from Driverless AI source directory.
280
+ E.g. (base) jon@gpu:~/h2oai$ pytest -s -v /data/jon/h2ogpt/create_data.py::test_config_to_json ; cp config.json /data/jon/h2ogpt/
281
+ :return:
282
+ """
283
+ try:
284
+ # Arrange
285
+ import json
286
+ from h2oaicore.systemutils import config
287
+ toml_list = []
288
+ for k, v in config.get_meta_dict().items():
289
+ title = (v.title + ": ") if v.title else ''
290
+ comment = v.comment or ''
291
+ if not (title or comment):
292
+ continue
293
+ toml_list.extend(
294
+ [
295
+ {
296
+ 'prompt_type': 'plain',
297
+ 'instruction': f"<human>: What does {k} do?\n<bot>: {k.replace('_', ' ')} config.toml: {comment or title}\n<human>:".replace(
298
+ "\n", ""),
299
+ },
300
+ {
301
+ 'prompt_type': 'plain',
302
+ 'instruction': f"<human>: Explain {k}.\n<bot>: {k.replace('_', ' ')} config.toml: {comment or title}\n<human>:".replace(
303
+ "\n", ""),
304
+ },
305
+ {
306
+ 'prompt_type': 'plain',
307
+ 'instruction': f"<human>: How can I do this: {title}.\n<bot>: Set the {k.replace('_', ' ')} config.toml\n<human>:".replace(
308
+ "\n", ""),
309
+ } if title and comment else None,
310
+ {
311
+ 'prompt_type': 'human_bot',
312
+ 'instruction': f'Explain the following expert setting for Driverless AI',
313
+ 'input': f"{k}",
314
+ 'output': f"{k.replace('_', ' ')} config.toml: {comment or title}".replace("\n", ""),
315
+ },
316
+ {
317
+ 'prompt_type': 'human_bot',
318
+ 'instruction': f'Explain the following expert setting for Driverless AI',
319
+ 'input': f"{k}",
320
+ 'output': f"{k.replace('_', ' ')} config.toml: {title}{comment}".replace("\n", ""),
321
+ },
322
+ {
323
+ 'prompt_type': 'human_bot',
324
+ 'instruction': f'Explain the following expert setting for Driverless AI',
325
+ 'input': f"{k.replace('_', ' ')}",
326
+ 'output': f"{k.replace('_', ' ')} config.toml: {title}{comment}".replace("\n", ""),
327
+ },
328
+ {
329
+ 'prompt_type': 'human_bot',
330
+ 'instruction': f'Explain the following expert setting for Driverless AI',
331
+ 'input': f"{title}",
332
+ 'output': f"{k.replace('_', ' ')} config.toml: {title}{comment}".replace("\n", ""),
333
+ },
334
+ {
335
+ 'prompt_type': 'human_bot',
336
+ 'instruction': f'Provide a short explanation of the expert setting {k}',
337
+ 'output': f"{k.replace('_', ' ')} config.toml: {comment or title}".replace("\n", ""),
338
+ },
339
+ {
340
+ 'prompt_type': 'human_bot',
341
+ 'instruction': f'Provide a detailed explanation of the expert setting {k}',
342
+ 'output': f"{k.replace('_', ' ')} config.toml: {title}{comment}".replace("\n", ""),
343
+ },
344
+ ]
345
+ )
346
+ toml_list = [x for x in toml_list if x]
347
+ with open("config.json", "wt") as f:
348
+ f.write(json.dumps(toml_list, indent=2))
349
+ except Exception as e:
350
+ print("Exception: %s" % str(e), flush=True)
351
+
352
+
353
+ def copy_tree(src, dst, follow_symlink=False):
354
+ makedirs(dst, exist_ok=True)
355
+ for (path, dirs, files) in os.walk(src, followlinks=follow_symlink):
356
+ new_path = path.replace(src, dst)
357
+ makedirs(new_path, exist_ok=True)
358
+ for file in files:
359
+ filename = os.path.join(path, file)
360
+ new_filename = os.path.join(new_path, file)
361
+ # print("%s -> %s" % (filename, new_filename))
362
+ try:
363
+ atomic_copy(filename, new_filename)
364
+ except FileNotFoundError:
365
+ pass
366
+
367
+
368
+ def atomic_move(src, dst):
369
+ try:
370
+ shutil.move(src, dst)
371
+ except (shutil.Error, FileExistsError):
372
+ pass
373
+ remove(src)
374
+
375
+
376
+ def atomic_copy(src=None, dst=None, with_permissions=True):
377
+ if os.path.isfile(dst):
378
+ return
379
+ import uuid
380
+ my_uuid = uuid.uuid4()
381
+ dst_tmp = dst + str(my_uuid)
382
+ makedirs(os.path.dirname(dst), exist_ok=True)
383
+ if with_permissions:
384
+ shutil.copy(src, dst_tmp)
385
+ else:
386
+ shutil.copyfile(src, dst_tmp)
387
+ atomic_move(dst_tmp, dst)
388
+ remove(dst_tmp)
389
+
390
+
391
+ def makedirs(path, exist_ok=True):
392
+ """
393
+ Avoid some inefficiency in os.makedirs()
394
+ :param path:
395
+ :param exist_ok:
396
+ :return:
397
+ """
398
+ if os.path.isdir(path) and os.path.exists(path):
399
+ assert exist_ok, "Path already exists"
400
+ return path
401
+ os.makedirs(path, exist_ok=exist_ok)
402
+
403
+
404
+ ## Download from https://huggingface.co/datasets/anon8231489123/ShareGPT_Vicuna_unfiltered/resolve/main/ShareGPT_unfiltered_cleaned_split.json
405
+ ## Turn into simple instruct prompt type. No context/previous conversations.
406
+ def test_prep_instruct_vicuna():
407
+ from datasets import load_dataset
408
+ filename = 'ShareGPT_unfiltered_cleaned_split.json'
409
+ if not os.path.exists(filename):
410
+ os.system(
411
+ 'wget https://huggingface.co/datasets/anon8231489123/ShareGPT_Vicuna_unfiltered/resolve/main/%s' % filename)
412
+ data = load_dataset("json", data_files={"train": filename})["train"]
413
+ training_rows = []
414
+ for i in range(data.num_rows):
415
+ conversations = data[i]['conversations']
416
+ assert isinstance(conversations, list), conversations
417
+ convo = ""
418
+ for j, conv in enumerate(conversations):
419
+ # Get ready for generate.py prompt_type=human_bot
420
+ # But train with prompt_type=plain
421
+ if conv['from'] == 'human':
422
+ FROM = '<human>: '
423
+ elif conv['from'] == 'gpt':
424
+ FROM = '<bot>: '
425
+ convo += f"{FROM}" + conv['value'] + "\n"
426
+ if convo:
427
+ training_rows.append(dict(input=convo))
428
+ with open(filename + ".generate_human_bot.train_plain.json", "wt") as f:
429
+ f.write(json.dumps(training_rows, indent=2))
430
+
431
+
432
+ POSTFIX = ".generate_human_bot.train_plain.json"
433
+
434
+ # https://bair.berkeley.edu/blog/2023/04/03/koala/
435
+ OIG_DATASETS = [
436
+ "unified_chip2.jsonl",
437
+ "unified_grade_school_math_instructions.jsonl",
438
+ "unified_poetry_2_song.jsonl",
439
+ "unified_plot_screenplay_books_dialog.jsonl",
440
+ ]
441
+
442
+ # hub issue: https://huggingface.co/datasets/laion/OIG/discussions/4
443
+ ALL_OIG_DATASETS = ['unified_abstract_infill.jsonl',
444
+ 'unified_basic.jsonl',
445
+ 'unified_canadian_parliament.jsonl',
446
+ 'unified_chip2.jsonl',
447
+ 'unified_conv_finqa.jsonl',
448
+ 'unified_cuad.jsonl',
449
+ 'unified_essays.jsonl',
450
+ 'unified_flan.jsonl.gz',
451
+ 'unified_grade_school_math_instructions.jsonl',
452
+ 'unified_hc3_human.jsonl',
453
+ 'unified_image_prompts_instructions.jsonl',
454
+ 'unified_joke_explanations.jsonl',
455
+ 'unified_mathqa_flanv2_kojma_cot.jsonl',
456
+ 'unified_merged_code_xp3.jsonl',
457
+ 'unified_multi_news.jsonl',
458
+ 'unified_multi_sum.jsonl',
459
+ 'unified_ni.jsonl.gz',
460
+ 'unified_nq.jsonl',
461
+ 'unified_openai_summarize_tldr.jsonl',
462
+ 'unified_oscar_en_sample_dialog.jsonl',
463
+ 'unified_p3.jsonl.gz',
464
+ 'unified_plot_screenplay_books_dialog.jsonl',
465
+ 'unified_poetry_2_song.jsonl',
466
+ 'unified_poetry_instructions.jsonl',
467
+ 'unified_rallio_safety_and_prosocial.jsonl',
468
+ 'unified_rallio_soda_upgraded_2048.jsonl',
469
+ 'unified_soda_dialog.jsonl',
470
+ 'unified_sqlv1.jsonl',
471
+ 'unified_sqlv2.jsonl',
472
+ 'unified_squad_v2.jsonl',
473
+ 'unified_squad_v2_more_neg.jsonl',
474
+ 'unified_ul2_plus_oscar_en_sample_dialog.jsonl',
475
+ 'unified_unifiedskg_instructions.jsonl',
476
+ 'unified_unnatural_instructions.jsonl',
477
+ 'unified_xp3_sample.jsonl']
478
+
479
+ useful_oig_files = ['unified_rallio_safety_and_prosocial.jsonl.parquet',
480
+ 'unified_chip2.jsonl.parquet',
481
+ 'unified_cuad.jsonl.parquet',
482
+ 'unified_essays.jsonl.parquet',
483
+ 'unified_flan.jsonl.gz.parquet',
484
+ 'unified_grade_school_math_instructions.jsonl.parquet',
485
+ 'unified_hc3_human.jsonl.parquet',
486
+ 'unified_mathqa_flanv2_kojma_cot.jsonl.parquet',
487
+ 'unified_merged_code_xp3.jsonl.parquet',
488
+ 'unified_multi_news.jsonl.parquet',
489
+ # 'unified_multi_sum.jsonl.parquet'
490
+ 'unified_ni.jsonl.gz.parquet',
491
+ 'unified_openai_summarize_tldr.jsonl.parquet',
492
+ # 'unified_oscar_en_sample_dialog.jsonl.parquet', # create text containing these N words, not specific
493
+ 'unified_plot_screenplay_books_dialog.jsonl.parquet',
494
+ 'unified_soda_dialog.jsonl.parquet',
495
+ 'unified_unnatural_instructions.jsonl.parquet',
496
+ ]
497
+
498
+
499
+ @pytest.mark.parametrize("filename", OIG_DATASETS)
500
+ def test_get_small_sample_oig_data(filename):
501
+ if not os.path.exists(filename):
502
+ os.system('wget https://huggingface.co/datasets/laion/OIG/resolve/main/%s' % filename)
503
+ import json
504
+ rows = []
505
+ with open(filename, "r") as f:
506
+ for line in f.readlines():
507
+ row = json.loads(line)
508
+ rows.append(dict(input=row["text"]))
509
+ with open(filename + POSTFIX, "w") as f:
510
+ f.write(json.dumps(rows, indent=2))
511
+
512
+
513
+ @pytest.mark.parametrize("filename", ALL_OIG_DATASETS)
514
+ def test_download_useful_data_as_parquet(filename):
515
+ dest_file = filename + '.parquet'
516
+ if dest_file not in useful_oig_files:
517
+ pytest.skip('file declared not useful')
518
+ if not os.path.exists(filename):
519
+ os.system('wget https://huggingface.co/datasets/laion/OIG/resolve/main/%s' % filename)
520
+ if not os.path.exists(dest_file):
521
+ df = pd.read_json(path_or_buf=filename, lines=True)
522
+ df.to_parquet(dest_file, index=False)
523
+
524
+
525
+ def test_merge_shuffle_small_sample_oig_data():
526
+ np.random.seed(1234)
527
+ rows = []
528
+ for filename in OIG_DATASETS:
529
+ with open(filename + POSTFIX, "r") as f:
530
+ rows.extend(json.loads(f.read()))
531
+ np.random.shuffle(rows)
532
+ with open("merged_shuffled_OIG_%s.json" % hashlib.sha256(str(OIG_DATASETS).encode()).hexdigest()[:10], "w") as f:
533
+ f.write(json.dumps(rows, indent=2))
534
+
535
+
536
+ def test_join_jsons():
537
+ files = ['config.json'] * 1 + \
538
+ ['dai_docs.train_cleaned.json'] * 2 + \
539
+ ['dai_faq.json'] * 3
540
+ print(files)
541
+ lst = []
542
+ [lst.extend(json.load(open(fil, 'rt'))) for fil in files]
543
+ print(len(lst))
544
+ json.dump(lst, open("merged.json", "wt"), indent=2)
545
+
546
+
547
+ @pytest.mark.parametrize("filename", ['Anthropic/hh-rlhf'])
548
+ def test_make_rlhf_good_data(filename):
549
+ from datasets import load_dataset
550
+ rows = load_dataset(filename)["train"]["chosen"]
551
+ new_rows = []
552
+ for row in rows:
553
+ if row[:2] == "\n\n":
554
+ row = row[2:]
555
+ row = row.replace("Human: ", "<human>: ")
556
+ row = row.replace("Assistant: ", "<bot>: ")
557
+ new_rows.append(dict(input=row))
558
+ with open(filename.replace("/", "_") + POSTFIX, "w") as f:
559
+ f.write(json.dumps(new_rows, indent=2))
560
+
561
+
562
+ def test_show_prompts():
563
+ files = ['config.json'] * 1 + \
564
+ ['dai_docs.train_cleaned.json'] * 1 + \
565
+ ['dai_faq.json'] * 1
566
+ file_points = [json.load(open(fil, 'rt')) for fil in files]
567
+ from prompter import generate_prompt
568
+ for data_points in file_points:
569
+ for data_point in data_points:
570
+ print(generate_prompt(data_point, 'plain', '', False, False)[0])
571
+
572
+
573
+ def test_get_open_datasets():
574
+ # HF changed things so don't get raw list of all datasets, so not have to filter, but can't do negative filter
575
+ open_tags = ['license:Apache License 2.0',
576
+ 'license:mit',
577
+ 'license:apache',
578
+ 'license:apache2',
579
+ 'license:apache-2.0',
580
+ 'license:bsd',
581
+ 'license:bsd-2-clause',
582
+ 'license:bsd-3-clause',
583
+ 'license:bsd-3-clause-clear',
584
+ 'license:lgpl-2.1',
585
+ 'license:lgpl-3.0',
586
+ 'license:lgpl-lr',
587
+ 'license:lgpl',
588
+ 'license:openrail++',
589
+ 'license:openrail',
590
+ 'license:bigscience-bloom-rail-1.0',
591
+ # 'license:agpl-3.0',
592
+ 'license:other',
593
+ 'license:unknown',
594
+ # 'license:mpl-2.0', # ok, but would have to include original copyright, license, source, copies in distribution
595
+ # Attribution required:
596
+ 'license:odc-by',
597
+ 'license:cc-by-4.0',
598
+ 'license:cc-by-3.0',
599
+ 'license:cc-by-2.0',
600
+ 'license:cc-by-2.5',
601
+ # 'license:cc-by-sa-4.0', # would require same license
602
+ 'license:odbl',
603
+ 'license:pddl',
604
+ 'license:ms-pl',
605
+ 'license:zlib',
606
+ ]
607
+ # bad license: cc-by-nc-4.0
608
+
609
+ from huggingface_hub import list_datasets
610
+ datasets = flatten_list([[x for x in list_datasets(filter=y)] for y in open_tags])
611
+ datasets += [x for x in list_datasets(author='openai')]
612
+ # check all:
613
+ all_license_tags = set(flatten_list([[y for y in x.tags if 'license' in y] for x in datasets]))
614
+ print(len(all_license_tags))
615
+ open_datasets = [x for x in datasets if any([y in x.tags for y in open_tags]) or 'license:' not in str(x.tags)]
616
+ print('open_datasets', len(open_datasets))
617
+ all_task_tags = set(flatten_list([[y for y in x.tags if 'task' in y] for x in open_datasets]))
618
+ print('all_task_tags', len(all_task_tags))
619
+ excluded_tags = ['image', 'hate', 'tabular', 'table-', 'classification', 'retrieval',
620
+ 'translation', 'identification', 'object', 'mask', 'to-text',
621
+ 'face-detection', 'audio', 'voice', 'reinforcement', 'depth-est',
622
+ 'forecasting', 'parsing', 'visual', 'speech', 'multiple-choice',
623
+ 'slot-filling', 'irds/argsme', '-scoring', 'other', 'graph-ml',
624
+ 'feature-extraction', 'keyword-spotting',
625
+ 'coreference-resolution', 'segmentation',
626
+ 'word-sense-disambiguation',
627
+ 'lemmatization']
628
+ task_tags = [x.replace('task_categories:', '').replace('task_ids:', '')
629
+ for x in all_task_tags if not any([y in x for y in
630
+ excluded_tags])]
631
+ print('task_tags', len(task_tags))
632
+ # str(x.tags) to catch any pattern match to anything in list
633
+ open_tasked_datasets = [x for x in open_datasets if
634
+ any([y in str([x for x in x.tags if 'task' in x]) for y in task_tags]) and
635
+ not any([y in str([x for x in x.tags if 'task' in x]) for y in excluded_tags]) or
636
+ 'task_categories' not in str(x.tags) and 'task_ids' not in str(x.tags)]
637
+ open_tasked_datasets = [x for x in open_tasked_datasets if not x.disabled]
638
+ open_tasked_datasets = [x for x in open_tasked_datasets if not x.gated]
639
+ open_tasked_datasets = [x for x in open_tasked_datasets if not x.private]
640
+ print('open_tasked_datasets', len(open_tasked_datasets))
641
+ sizes = list(set(flatten_list([[(y, x.id) for y in x.tags if 'size' in y] for x in open_tasked_datasets])))
642
+ languages = list(set(flatten_list([[(y, x.id) for y in x.tags if 'language:' in y] for x in open_tasked_datasets])))
643
+ open_english_tasked_datasets = [x for x in open_tasked_datasets if
644
+ 'language:' not in str(x.tags) or
645
+ 'language:en' in str(x.tags)]
646
+ small_open_english_tasked_datasets = [x for x in open_english_tasked_datasets if
647
+ 'n<1K' in str(x.tags) or
648
+ '1K<n<10K' in str(x.tags) or
649
+ '1K0<n<100K' in str(x.tags) or
650
+ '100K<n<1M' in str(x.tags) or
651
+ 'size_category' not in str(x.tags)
652
+ ]
653
+ # 'aeslc' : email_body, subject -> summarization?
654
+ # load_dataset(open_tasked_datasets[0].id).data['train'].to_pandas()
655
+ ids = [x.id for x in small_open_english_tasked_datasets]
656
+
657
+ # sanity checks
658
+ # https://bair.berkeley.edu/blog/2023/04/03/koala/
659
+ assert 'alespalla/chatbot_instruction_prompts' in ids
660
+ assert 'laion/OIG' in ids
661
+ assert 'openai/webgpt_comparisons' in ids
662
+ assert 'openai/summarize_from_feedback' in ids
663
+ assert 'Anthropic/hh-rlhf' in ids
664
+
665
+ # useful but not allowed for commercial purposes:
666
+ # https://huggingface.co/datasets/squad
667
+
668
+ print('open_english_tasked_datasets: ', ids, flush=True)
669
+
670
+ exclude_ids = ['allenai/nllb', # translation only
671
+ 'hf-internal-testing/fixtures_image_utils', # testing
672
+ 'allenai/c4', # search-url
673
+ 'agemagician/uniref50', # unknown
674
+ 'huggingface-course/documentation-images', # images
675
+ 'smilegate-ai/kor_unsmile', # korean
676
+ 'MohamedRashad/ChatGPT-prompts', # ChatGPT/LearnGPT/https://www.emergentmind.com/
677
+ 'humarin/chatgpt-paraphrases', # Paraphrase using ChatGPT
678
+ 'Jeska/vaccinchat', # not useful
679
+ 'alespalla/chatbot_instruction_prompts', # mixes alpaca
680
+ 'allenai/prosocial-dialog',
681
+ # already exlucded, but wrongly in other datasets that say more permissive license
682
+ 'AlekseyKorshuk/persona-chat', # low quality
683
+ 'bavard/personachat_truecased', # low quality
684
+ 'adamlin/daily_dialog', # medium quality conversations
685
+ 'adamlin/FewShotWoz', # low quality
686
+ 'benjaminbeilharz/better_daily_dialog', # low quality
687
+ 'benjaminbeilharz/daily_dialog_w_turn_templates', # low
688
+ 'benjaminbeilharz/empathetic_dialogues_for_lm', # low
689
+ 'GEM-submissions/GEM__bart_base_schema_guided_dialog__1645547915', # NA
690
+ 'ia-bentebib/conv_ai_2_fr', # low fr
691
+ 'ia-bentebib/daily_dialog_fr', # low fr
692
+ 'ia-bentebib/dialog_re_fr', # low fr
693
+ 'ia-bentebib/empathetic_dialogues_fr', # low fr
694
+ 'roskoN/dailydialog', # low
695
+ 'VadorMazer/skyrimdialogstest', # low
696
+ 'bigbio/med_qa', # med specific Q/A
697
+ 'biu-nlp/qa_srl2018', # low quality Q/A
698
+ 'biu-nlp/qa_discourse', # low quality Q/A
699
+ 'iarfmoose/qa_evaluator', # low quality Q/A
700
+ 'jeopardy', # low quality Q/A -- no reasoning
701
+ 'narrativeqa', # low quality Q/A
702
+ 'nomic-ai/gpt4all_prompt_generations', # bad license
703
+ 'nomic-ai/gpt4all_prompt_generations_with_p3', # bad license
704
+ 'HuggingFaceH4/alpaca', # bad license
705
+ 'tatsu-lab/alpaca', # ToS breaking
706
+ 'yahma/alpaca-cleaned', # ToS breaking
707
+ 'Hello-SimpleAI/HC3', # bad license
708
+ 'glue', # no reasoning QA
709
+ 'sahil2801/CodeAlpaca-20k', # bad license
710
+ 'Short-Answer-Feedback/saf_communication_networks_english', # long Q, medium A
711
+ ]
712
+ small_open_english_tasked_datasets = [x for x in small_open_english_tasked_datasets if x.id not in exclude_ids]
713
+ # some ids clearly speech related
714
+ small_open_english_tasked_datasets = [x for x in small_open_english_tasked_datasets if 'speech' not in x.id]
715
+ # HF testing
716
+ small_open_english_tasked_datasets = [x for x in small_open_english_tasked_datasets if
717
+ 'hf-internal-testing' not in x.id]
718
+ small_open_english_tasked_datasets = [x for x in small_open_english_tasked_datasets if
719
+ 'chinese' not in x.id]
720
+
721
+ sorted_small_open_english_tasked_datasets = sorted([(x.downloads, x) for x in small_open_english_tasked_datasets],
722
+ key=lambda x: x[0], reverse=True)
723
+
724
+ # NOTES:
725
+ # Run like pytest -s -v create_data.py::test_get_open_datasets &> getdata9.log
726
+ # See what needs config passed and add:
727
+ # grep 'load_dataset(' getdata9.log|grep -v data_id|less -S
728
+ # grep "pip install" getdata9.log
729
+ # NOTE: Some datasets have default config, but others are there. Don't know how to access them.
730
+
731
+ """
732
+ https://huggingface.co/datasets/wikihow/blob/main/wikihow.py
733
+ https://github.com/mahnazkoupaee/WikiHow-Dataset
734
+ https://ucsb.box.com/s/ap23l8gafpezf4tq3wapr6u8241zz358
735
+ https://ucsb.app.box.com/s/ap23l8gafpezf4tq3wapr6u8241zz358
736
+ """
737
+
738
+ """
739
+ # some ambiguous or non-commercial datasets
740
+ https://github.com/PhoebusSi/alpaca-CoT
741
+ """
742
+
743
+ timeout = 3 * 60
744
+ # laion/OIG takes longer
745
+ for num_downloads, dataset in sorted_small_open_english_tasked_datasets:
746
+ data_id = dataset.id
747
+ func = do_one
748
+ args = (data_id, num_downloads)
749
+ kwargs = {}
750
+ with ProcessPoolExecutor(max_workers=1) as executor:
751
+ future = executor.submit(func, *args, **kwargs)
752
+ try:
753
+ future.result(timeout=timeout)
754
+ except concurrent.futures.TimeoutError:
755
+ print("\n\ndata_id %s timeout\n\n" % data_id, flush=True)
756
+ for child in psutil.Process(os.getpid()).children(recursive=True):
757
+ os.kill(child.pid, signal.SIGINT)
758
+ os.kill(child.pid, signal.SIGTERM)
759
+ os.kill(child.pid, signal.SIGKILL)
760
+
761
+
762
+ def do_one(data_id, num_downloads):
763
+ from datasets import load_dataset
764
+ out_file = "data_%s.parquet" % str(data_id.replace('/', '_'))
765
+ if os.path.isfile(out_file) and os.path.getsize(out_file) > 1024 ** 3:
766
+ return
767
+ try:
768
+ print("Loading data_id %s num_downloads: %s" % (data_id, num_downloads), flush=True)
769
+ avail_list = None
770
+ try:
771
+ data = load_dataset(data_id, 'foobar')
772
+ except Exception as e:
773
+ if 'Available: ' in str(e):
774
+ avail_list = ast.literal_eval(str(e).split('Available:')[1].strip())
775
+ else:
776
+ avail_list = None
777
+ if avail_list is None:
778
+ avail_list = [None]
779
+ print("%s avail_list: %s" % (data_id, avail_list), flush=True)
780
+
781
+ for name in avail_list:
782
+ out_file = "data_%s_%s.parquet" % (str(data_id.replace('/', '_')), str(name))
783
+ if os.path.isfile(out_file):
784
+ continue
785
+ data = load_dataset(data_id, name)
786
+ column_names_dict = data.column_names
787
+ column_names = column_names_dict[list(column_names_dict.keys())[0]]
788
+ print("Processing data_id %s num_downloads: %s columns: %s" % (data_id, num_downloads, column_names),
789
+ flush=True)
790
+ data_dict = data.data
791
+ col_dict = data.num_columns
792
+ first_col = list(col_dict.keys())[0]
793
+ if 'train' in data_dict:
794
+ df = data['train'].to_pandas()
795
+ else:
796
+ df = data[first_col].to_pandas()
797
+ # csv has issues with escaping chars, even for datasets I know I want
798
+ df.to_parquet(out_file, index=False)
799
+ except Exception as e:
800
+ t, v, tb = sys.exc_info()
801
+ ex = ''.join(traceback.format_exception(t, v, tb))
802
+ print("Exception: %s %s" % (data_id, ex), flush=True)
803
+
804
+
805
+ def test_otherlic():
806
+ from huggingface_hub import list_datasets
807
+ lic = ['license:odc-by',
808
+ 'license:cc-by-4.0',
809
+ 'license:cc-by-3.0',
810
+ 'license:cc-by-2.0',
811
+ 'license:cc-by-2.5',
812
+ 'license:cc-by-sa-4.0',
813
+ 'license:odbl',
814
+ 'license:pddl',
815
+ 'license:ms-pl',
816
+ 'license:zlib',
817
+ ]
818
+ datasets = flatten_list([[x for x in list_datasets(filter=y) if 'translation' not in str(x.tags)] for y in lic])
819
+ print(len(datasets))
820
+
821
+
822
+ # These useful datasets are determined based upon data sample, column types, and uniqueness compared to larger datasets like Pile
823
+ # grep columns getdata13.log|grep -v "\['image'\]"|sort|uniq|grep -v tokens|grep -v "'image'"|grep -v embedding|grep dialog
824
+ useful = ['Dahoas/instruct-human-assistant-prompt',
825
+ 'Dahoas/first-instruct-human-assistant-prompt',
826
+ 'knkarthick/dialogsum', # summary of conversation
827
+ 'McGill-NLP/FaithDial', # medium quality
828
+ 'Zaid/quac_expanded', # medium quality context + QA
829
+ '0-hero/OIG-small-chip2', # medium
830
+ 'alistvt/coqa-flat', # QA medium
831
+ 'AnonymousSub/MedQuAD_47441_Question_Answer_Pairs', # QA medium
832
+ 'Anthropic/hh-rlhf', # high quality # similar to Dahoas/full-hh-rlhf
833
+ 'arjunth2001/online_privacy_qna', # good quality QA
834
+ 'Dahoas/instruct_helpful_preferences', # medium quality instruct
835
+ 'Dahoas/rl-prompt-dataset', # medium chat
836
+ 'Dahoas/rm-static', # medium chat
837
+ 'Dahoas/static-hh', # medium chat # HuggingFaceH4/self_instruct
838
+ 'Dahoas/synthetic-instruct-gptj-pairwise', # medium chat
839
+ 'eli5', # QA if prompt ELI5
840
+ 'gsm8k', # QA (various)
841
+ 'guanaco/guanaco', # prompt/response
842
+ 'kastan/rlhf-qa-comparisons', # good QA
843
+ 'kastan/rlhf-qa-conditional-generation-v2', # prompt answer
844
+ 'OllieStanley/humaneval-mbpp-codegen-qa', # code QA, but started from words, so better than other code QA
845
+ 'OllieStanley/humaneval-mbpp-testgen-qa', # code QA
846
+ 'Graverman/Instruct-to-Code', # code QA
847
+ 'openai/summarize_from_feedback', # summarize
848
+ 'relbert/analogy_questions', # analogy QA
849
+ 'yitingxie/rlhf-reward-datasets', # prompt, chosen, rejected.
850
+ 'yizhongw/self_instruct', # instruct (super natural & instruct)
851
+ 'HuggingFaceH4/asss', # QA, big A
852
+ 'kastan/rlhf-qa-conditional-generation-v2', # QA
853
+ 'cosmos_qa', # context QA
854
+ 'vishal-burman/c4-faqs', # QA but not much reasoning, though a lot of text
855
+ 'squadshifts', # QA from context
856
+ 'hotpot_qa', # QA from context
857
+ 'adversarial_qa', # QA from context
858
+ 'allenai/soda', # dialog -> narrative/summary
859
+ 'squad_v2', # context QA
860
+ 'squadshifts', # context QA
861
+ 'dferndz/cSQuAD1', # context QA
862
+ 'dferndz/cSQuAD2', # context QA
863
+ 'din0s/msmarco-nlgen', # context QA
864
+ 'domenicrosati/TruthfulQA', # common sense truthful QA -- trivia but good trivia
865
+ 'hotpot_qa', # context, QA
866
+ 'HuggingFaceH4/self-instruct-eval', # instruct QA, medium quality, some language reasoning
867
+ 'kastan/EE_QA_for_RLHF', # context QA
868
+ 'KK04/LogicInference_OA', # instruction logical QA
869
+ 'lmqg/qa_squadshifts_synthetic', # context QA
870
+ 'lmqg/qg_squad', # context QA
871
+ 'lmqg/qg_squadshifts', # context QA
872
+ 'lmqg/qg_subjqa', # context QA
873
+ 'pszemraj/HC3-textgen-qa',
874
+ # QA medium, has human responses -- humans tend to provide links instead of trying to answer
875
+ 'pythonist/newdata', # long context, QA, brief A
876
+ 'ropes', # long background, situation, question, A
877
+ 'wikitablequestions', # table -> QA
878
+ 'bigscience/p3', # context QA but short answers
879
+ ]
880
+
881
+ code_useful = ['0n1xus/codexglue',
882
+ 'openai_humaneval',
883
+ 'koutch/staqc',
884
+ ]
885
+
886
+ maybe_useful = ['AlekseyKorshuk/comedy-scripts',
887
+ 'openbookqa', # hard to parse, low reasoning
888
+ 'qed', # reasonable QA, but low reasoning
889
+ 'selqa', # candidate answers
890
+ 'HuggingFaceH4/instruction-pilot-outputs-filtered',
891
+ 'GBaker/MedQA-USMLE-4-options', # medical QA with long questions
892
+ 'npc-engine/light-batch-summarize-dialogue', # dialog summarize, kinda low specific quality
893
+ ]
894
+
895
+ summary_useful = ['austin/rheum_abstracts',
896
+ 'CarperAI/openai_summarize_comparisons', # summarize chosen/rejected
897
+ 'CarperAI/openai_summarize_tldr', # summarize QA
898
+ 'ccdv/cnn_dailymail', # summarize news
899
+ 'ccdv/govreport-summarization', # summarize high quality
900
+ 'ccdv/pubmed-summarization', # summarize high quality
901
+ 'duorc', # plot -> QA
902
+ 'farleyknight/big_patent_5_percent', # desc -> abstract
903
+ 'multi_news', # summary
904
+ 'opinosis',
905
+ 'SophieTr/reddit_clean',
906
+ 'allenai/mup', # long text -> summary
907
+ 'allenai/multi_lexsum', # long text -> summary
908
+ 'big_patent',
909
+ 'allenai/wcep_dense_max',
910
+ 'awinml/costco_long_practice',
911
+ 'GEM/xsum',
912
+ 'ratishsp/newshead',
913
+ 'RussianNLP/wikiomnia', # russian
914
+ 'stacked-summaries/stacked-xsum-1024',
915
+ ]
916
+
917
+ math_useful = [
918
+ 'competition_math'
919
+ ]
920
+
921
+ skipped = ['c4', # maybe useful, used for flan, but skipped due to size
922
+ ]
923
+
924
+ """
925
+ To get training data from oig:
926
+ pytest test_oig test_grade_final test_finalize_to_json
927
+ """
928
+
929
+ human = '<human>:'
930
+ bot = '<bot>:'
931
+
932
+
933
+ def test_assemble_and_detox():
934
+ import re
935
+ from profanity_check import predict_prob
936
+ df_list = []
937
+ for data in useful_oig_files:
938
+ print("Processing %s" % data, flush=True)
939
+ df = pd.read_parquet(data)
940
+ df = df.reset_index(drop=True)
941
+ # chop up into human/bot interactions, capping the length of each row (see max_len / MAX_LEN below)
942
+ text_list = df[['text']].values.ravel().tolist()
943
+ new_text = []
944
+ max_len = 2048 # uber cutoff
945
+ MAX_LEN = 2048 // 2 - 30 # max len per question/answer
946
+ for text in tqdm(text_list):
947
+ human_starts = [m.start() for m in re.finditer('<human>: ', text)]
948
+ if len(human_starts) == 1:
949
+ human_starts = [0, len(text)] # always go into for loop below
950
+ blurb = ''
951
+ for i in range(len(human_starts) - 1):
952
+ interaction = text[human_starts[i]: human_starts[i + 1]][:max_len]
953
+ blurb += interaction
954
+ if len(blurb) >= MAX_LEN:
955
+ blurb = get_sentences(blurb, length=MAX_LEN)[0]
956
+ new_text.append(blurb + "\n<human>:")
957
+ blurb = ''
958
+ if blurb:
959
+ blurb = get_sentences(blurb, length=MAX_LEN)[0]
960
+ new_text.append(blurb + "\n<human>:")
961
+
962
+ if len(new_text) > len(text_list):
963
+ print("Added %d new rows (before: %d)" % (len(new_text) - df.shape[0], df.shape[0]))
964
+ df = pd.DataFrame({"text": new_text, "source": [data] * len(new_text)})
965
+ df = df.drop_duplicates(keep='first')
966
+ print(df['text'].apply(lambda x: len(x)).describe())
967
+ assert df['text'].apply(lambda x: len(x)).max() <= 2 * max_len
968
+
969
+ # faster than better_profanity, do early
970
+ df['profanity'] = predict_prob(df['text'])
971
+ before_rows = df.shape[0]
972
+ df = df[df['profanity'] < 0.25] # drop any low quality stuff
973
+ after_rows = df.shape[0]
974
+ print("Dropped %d rows out of %d due to alt-profanity-check" % (before_rows - after_rows, before_rows))
975
+ df_list.append(df)
976
+ print("Done processing %s -> %s rows" % (data, df.shape[0]), flush=True)
977
+ print("So far have %d rows" % sum([len(x) for x in df_list]))
978
+ df_final = pd.concat(df_list)
979
+ df_final = df_final.sample(frac=1, random_state=1234).reset_index(drop=True)
980
+ df_final.to_parquet('h2oGPT.cleaned.human_bot.shorter.parquet', index=False)
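+ # Editor's sketch (assumption): get_sentences() is defined elsewhere in this repo; a stand-in
+ # with the same call shape (split text into chunks of roughly at most `length` characters on
+ # sentence-ish boundaries) could look like:
+ #
+ # import re
+ #
+ # def get_sentences_sketch(blurb, length):
+ #     sentences = re.split(r'(?<=[.!?])\s+', blurb)
+ #     chunks, current = [], ''
+ #     for s in sentences:
+ #         if not current or len(current) + 1 + len(s) <= length:
+ #             current = (current + ' ' + s).strip()
+ #         else:
+ #             chunks.append(current)
+ #             current = s
+ #     if current:
+ #         chunks.append(current)
+ #     return chunks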
981
+
982
+
983
+ def test_basic_cleaning():
984
+ # from better_profanity import profanity
985
+ # https://pypi.org/project/alt-profanity-check/
986
+ from profanity_check import predict
987
+ df_list = []
988
+ for data in useful_oig_files:
989
+ # for data in useful_oig_files[:5]:
990
+ # for data in ['unified_openai_summarize_tldr.jsonl.parquet']:
991
+ print("Processing %s" % data, flush=True)
992
+ df = pd.read_parquet(data)
993
+ df = df.reset_index(drop=True)
994
+ # NOTE: Not correct if there are multiple human-bot interactions, but such dialogs are even more desired
995
+ # avg_chars = len(df['text'][0])/(df['text'][0].count(human)+df['text'][0].count(bot))
996
+ df['avg_words'] = df['text'].apply(lambda x: x.count(' ') / (x.count(human) + x.count(bot)) / 2.0)
997
+ df['avg_bot_words'] = df['text'].apply(lambda x: x.split(bot)[1].count(' ') / x.count(bot))
998
+ # df['bad_words'] = df['text'].apply(lambda x: profanity.contains_profanity(x))
999
+ # low_quality_patterns = ['Write the rest of this wikipedia article']
1000
+ res = predict(df['text'])
1001
+ df['bad_words'] = res
1002
+ df = df.reset_index(drop=True)
1003
+ df = df[df['bad_words'] == 0]
1004
+ df = df[['text', 'avg_words', 'avg_bot_words']]
1005
+ df = df.drop_duplicates(keep='first')
1006
+ print(df[df['avg_words'] == df['avg_words'].max()]['text'].values)
1007
+ median_words = np.median(df['avg_words'])
1008
+ min_words_per_entity = max(30, 0.8 * median_words)
1009
+ max_words_per_entity = 2048 # too hard to learn from for now
1010
+ df = df[df['avg_words'] > min_words_per_entity]
1011
+ df = df[df['avg_words'] < max_words_per_entity]
1012
+
1013
+ min_words_per_entity = max(20, 0.5 * median_words) # bot should say stuff for now
1014
+ max_words_per_entity = 2048 # too hard to learn from for now
1015
+ df = df[df['avg_bot_words'] > min_words_per_entity]
1016
+ df = df[df['avg_bot_words'] < max_words_per_entity]
1017
+
1018
+ df_list.append(df)
1019
+ print("Done processing %s -> %s rows" % (data, df.shape[0]), flush=True)
1020
+ df_final = pd.concat(df_list)
1021
+ df_final.to_parquet('h2oGPT.cleaned.human_bot.parquet', index=False)
1022
+
1023
+
1024
+ from joblib import Parallel, delayed, effective_n_jobs
1025
+ from sklearn.utils import gen_even_slices
1026
+ from sklearn.utils.validation import _num_samples
1027
+
1028
+
1029
+ def parallel_apply(df, func, n_jobs=-1, **kwargs):
1030
+ """ Pandas apply in parallel using joblib.
1031
+ Uses sklearn.utils to partition input evenly.
1032
+
1033
+ Args:
1034
+ df: Pandas DataFrame, Series, or any other object that supports slicing and apply.
1035
+ func: Callable to apply
1036
+ n_jobs: Desired number of workers. Default value -1 means use all available cores.
1037
+ **kwargs: Any additional parameters will be supplied to the apply function
1038
+
1039
+ Returns:
1040
+ Same as for normal Pandas DataFrame.apply()
1041
+
1042
+ """
1043
+
1044
+ if effective_n_jobs(n_jobs) == 1:
1045
+ return df.apply(func, **kwargs)
1046
+ else:
1047
+ ret = Parallel(n_jobs=n_jobs)(
1048
+ delayed(type(df).apply)(df[s], func, **kwargs)
1049
+ for s in gen_even_slices(_num_samples(df), effective_n_jobs(n_jobs)))
1050
+ return pd.concat(ret)
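+ # Example usage (editor's sketch; data and function are illustrative):
+ # df_example = pd.DataFrame({'text': ['hello world', 'foo bar baz']})
+ # df_example['n_words'] = parallel_apply(df_example['text'], lambda x: len(x.split()), n_jobs=2)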
1051
+
1052
+
1053
+ def add_better_profanity_flag(df):
1054
+ from better_profanity import profanity
1055
+ df['better_profanity'] = parallel_apply(
1056
+ df['text'],
1057
+ lambda x: profanity.contains_profanity(x),
1058
+ n_jobs=-1,
1059
+ )
1060
+ return df
1061
+
1062
+
1063
+ def add_textstat_grade(df):
1064
+ import textstat
1065
+
1066
+ def myfunc(x):
1067
+ return textstat.flesch_kincaid_grade(x) # simple grade
1068
+
1069
+ if False:
1070
+ import dask.dataframe as dd
1071
+ # 40 seconds for 1000 rows, but have 1,787,799 rows
1072
+ ddata = dd.from_pandas(df, npartitions=120)
1073
+
1074
+ df['flesch_grade'] = ddata['text'].apply(myfunc).compute()
1075
+ if True:
1076
+ # fast way
1077
+ df['flesch_grade'] = parallel_apply(df['text'], myfunc, n_jobs=-1)
1078
+ return df
1079
+
1080
+
1081
+ def add_deberta_grade(df):
1082
+ from transformers import AutoModelForSequenceClassification, AutoTokenizer
1083
+ import torch
1084
+ reward_name = "OpenAssistant/reward-model-deberta-v3-large-v2"
1085
+ rank_model, tokenizer = AutoModelForSequenceClassification.from_pretrained(
1086
+ reward_name), AutoTokenizer.from_pretrained(reward_name)
1087
+ device = 'cuda' if torch.cuda.is_available() else 'cpu'
1088
+ rank_model.to(device)
1089
+
1090
+ def get_question(x):
1091
+ return x.replace('<human>: ', '').split('<bot>:')[0]
1092
+
1093
+ def get_answer(x):
1094
+ try:
1095
+ answer = x.split('<bot>: ')[1].split('<human>:')[0].replace('<bot>: ', '')
1096
+ except:
1097
+ answer = x.split('<bot>:')[1].split('<human>:')[0].replace('<bot>:', '')
1098
+ return answer
1099
+
1100
+ df['question'] = parallel_apply(df['text'], get_question, n_jobs=-1)
1101
+ df['answer'] = parallel_apply(df['text'], get_answer, n_jobs=-1)
1102
+
1103
+ from datasets import Dataset
1104
+ from transformers import pipeline
1105
+ from transformers.pipelines.pt_utils import KeyPairDataset
1106
+ import tqdm
1107
+
1108
+ pipe = pipeline(
1109
+ "text-classification",
1110
+ model=reward_name,
1111
+ device="cuda:0" if torch.cuda.is_available() else "cpu"
1112
+ )
1113
+ start = 0
1114
+ batch_size = 64 * 16
1115
+ micro_batch = orig_micro_batch = 16
1116
+ end = 0
1117
+ import socket
1118
+ checkpoint = "grades.%s.pkl" % socket.gethostname()
1119
+ grades = []
1120
+ import pickle
1121
+ if os.path.exists(checkpoint):
1122
+ with open(checkpoint, "rb") as f:
1123
+ start, grades = pickle.loads(f.read())
1124
+ last_oom = 0
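+ # (editor's note) the loop below scores rows in chunks of batch_size, halving the pipeline's
+ # micro_batch on each CUDA OOM and restoring it after a chunk completes cleanly; (end, grades)
+ # are pickled per hostname so an interrupted run can resume from the last finished chunk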
1125
+ while end < df.shape[0]:
1126
+ # manual batching to handle OOM more gracefully
1127
+ end = min(start + batch_size, df.shape[0])
1128
+ if start == end:
1129
+ break
1130
+ dataset = Dataset.from_pandas(df.iloc[start:end, :])
1131
+ try:
1132
+ grades.extend([
1133
+ x['score'] for x in tqdm.tqdm(
1134
+ pipe(KeyPairDataset(dataset, "question", "answer"), batch_size=micro_batch)
1135
+ )
1136
+ ])
1137
+ except torch.cuda.OutOfMemoryError:
1138
+ last_oom = start
1139
+ micro_batch = max(1, micro_batch // 2)
1140
+ print("OOM - retrying with micro_batch=%d" % micro_batch)
1141
+ continue
1142
+ if last_oom == start:
1143
+ micro_batch = orig_micro_batch
1144
+ print("Returning to micro_batch=%d" % micro_batch)
1145
+ assert len(grades) == end
1146
+ start = end
1147
+ with open(checkpoint, "wb") as f:
1148
+ f.write(pickle.dumps((end, grades)))
1149
+ print("%d/%d" % (end, df.shape[0]))
1150
+ df['grade_deberta'] = grades
1151
+ if os.path.exists(checkpoint):
1152
+ os.remove(checkpoint)
1153
+ return df
1154
+
1155
+
1156
+ def test_chop_by_lengths():
1157
+ file = "h2oGPT.cleaned.human_bot.shorter.parquet"
1158
+ df = pd.read_parquet(file).reset_index(drop=True)
1159
+ df = count_human_bot_lengths(df)
1160
+ df['rand'] = np.random.rand(df.shape[0])
1161
+ df['rand2'] = np.random.rand(df.shape[0])
1162
+ before_rows = df.shape[0]
1163
+ # throw away short human/bot responses with higher likelihood
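+ # (lengths are mean characters per turn: <= 20 always dropped, (20, 30] kept ~20% of the time,
+ # (30, 50] kept ~50% of the time, > 50 always kept; the same tiering applies to bot turns via rand2)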
1164
+ df = df[(df['len_human_mean'] > 20)] # never keep very short ones
1165
+ df = df[(df['len_human_mean'] > 30) | (df['rand'] < 0.2)]
1166
+ df = df[(df['len_human_mean'] > 50) | (df['rand'] < 0.5)]
1167
+ df = df[(df['len_human_max'] < 10000)] # drop super long (basically only human) ones
1168
+ df = df[(df['len_bot_mean'] > 20)] # never keep very short ones
1169
+ df = df[(df['len_bot_mean'] > 30) | (df['rand2'] < 0.2)]
1170
+ df = df[(df['len_bot_mean'] > 50) | (df['rand2'] < 0.5)]
1171
+ df = df[(df['len_bot_max'] < 10000)] # drop super long (only bot) ones
1172
+ assert df['text'].apply(lambda x: len(x)).max() < 20000
1173
+ df = df.drop(['rand', 'rand2'], axis=1)
1174
+ after_rows = df.shape[0]
1175
+ print("Chopped off %d out of %d rows due to length" % (before_rows - after_rows, before_rows))
1176
+ print(df.describe())
1177
+ df.to_parquet('h2oGPT.cleaned.chopped.human_bot.shorter.parquet', index=False)
1178
+
1179
+
1180
+ def count_human_bot_lengths(df, human=None, bot=None):
1181
+ import re
1182
+ len_human_min = []
1183
+ len_human_max = []
1184
+ len_human_mean = []
1185
+ len_bot_min = []
1186
+ len_bot_max = []
1187
+ len_bot_mean = []
1188
+ human = human or '<human>:'
1189
+ bot = bot or '<bot>:'
1190
+ for is_human in [True, False]:
1191
+ what = human if is_human else bot
1192
+ other = human if not is_human else bot
1193
+ for i in range(df.shape[0]):
1194
+ text = df.loc[i, 'text']
1195
+ assert isinstance(text, str)
1196
+ starts = [m.start() for m in re.finditer(what, text)]
1197
+ if len(starts) == 1:
1198
+ starts = [starts[0], len(text)] # always go into for loop below
1199
+ assert len(text)
1200
+ list_what = []
1201
+ for ii in range(len(starts) - 1):
1202
+ interaction = text[starts[ii]: starts[ii + 1]]
1203
+ if other in interaction:
1204
+ interaction = interaction[:interaction.find(other)]
1205
+ interaction = interaction.strip()
1206
+ list_what.append(interaction)
1207
+ if not list_what:
1208
+ list_what = [''] # handle corrupted data, very rare, leads to sizes 0
1209
+ if is_human:
1210
+ len_human_min.append(min([len(x) for x in list_what]))
1211
+ len_human_max.append(max([len(x) for x in list_what]))
1212
+ len_human_mean.append(np.mean([len(x) for x in list_what]))
1213
+ else:
1214
+ len_bot_min.append(min([len(x) for x in list_what]))
1215
+ len_bot_max.append(max([len(x) for x in list_what]))
1216
+ len_bot_mean.append(np.mean([len(x) for x in list_what]))
1217
+ df['len_human_min'] = len_human_min
1218
+ df['len_human_max'] = len_human_max
1219
+ df['len_human_mean'] = len_human_mean
1220
+ df['len_bot_min'] = len_bot_min
1221
+ df['len_bot_max'] = len_bot_max
1222
+ df['len_bot_mean'] = len_bot_mean
1223
+ np.random.seed(1234)
1224
+ pd.set_option('display.max_columns', None)
1225
+ print("Before chopping")
1226
+ print(df.describe())
1227
+ return df
1228
+
1229
+
1230
+ def test_grade():
1231
+ df = None
1232
+
1233
+ file = "h2oGPT.cleaned.chopped.human_bot.shorter.parquet"
1234
+ output_file = "h2oGPT.cleaned.graded1.human_bot.shorter.parquet"
1235
+ if not os.path.exists(output_file):
1236
+ if df is None:
1237
+ df = pd.read_parquet(file).reset_index(drop=True)
1238
+ df = add_textstat_grade(df)
1239
+ min_grade = 10
1240
+ max_grade = 25
1241
+ df = df[df['flesch_grade'] >= min_grade]
1242
+ df = df[df['flesch_grade'] <= max_grade]
1243
+ print("After Flesch grade")
1244
+ print(df.describe())
1245
+ df.to_parquet(output_file, index=False)
1246
+
1247
+ file = output_file
1248
+ output_file = "h2oGPT.cleaned.graded2.human_bot.shorter.parquet"
1249
+ if not os.path.exists(output_file):
1250
+ # slower than alt-profanity-check, so do it later, but before deberta grading, which is slower still
1251
+ if df is None:
1252
+ df = pd.read_parquet(file).reset_index(drop=True)
1253
+ df = add_better_profanity_flag(df)
1254
+ before_rows = df.shape[0]
1255
+ df = df[df['better_profanity'] == 0]
1256
+ df = df.drop(['better_profanity'], axis=1)
1257
+ after_rows = df.shape[0]
1258
+ print("Dropped %d rows out of %d due to better_profanity" % (before_rows - after_rows, before_rows))
1259
+ print(df.describe())
1260
+ df.to_parquet(output_file, index=False)
1261
+
1262
+ file = output_file
1263
+ output_file = 'h2oGPT.cleaned.graded3.human_bot.shorter.parquet'
1264
+ if not os.path.exists(output_file):
1265
+ if df is None:
1266
+ df = pd.read_parquet(file).reset_index(drop=True)
1267
+ df = add_deberta_grade(df)
1268
+ min_grade = 0.3
1269
+ max_grade = np.inf
1270
+ before_rows = df.shape[0]
1271
+ df = df[df['grade_deberta'] >= min_grade]
1272
+ df = df[df['grade_deberta'] <= max_grade]
1273
+ after_rows = df.shape[0]
1274
+ print("Dropped %d rows out of %d due to deberta grade" % (before_rows - after_rows, before_rows))
1275
+ print("After DeBERTa grade")
1276
+ print(df.describe())
1277
+ df.to_parquet(output_file, index=False)
1278
+
1279
+ file = output_file
1280
+ output_file = 'h2oGPT.cleaned.graded.human_bot.shorter.parquet'
1281
+ if df is None:
1282
+ df = pd.read_parquet(file).reset_index(drop=True)
1283
+ df.to_parquet(output_file, index=False)
1284
+
1285
+
1286
+ @pytest.mark.parametrize(
1287
+ "fixup_personality, only_personality, deberta_grading",
1288
+ [
1289
+ [False, False, False],
1290
+ [True, True, False],
1291
+ [True, False, False],
1292
+ [True, False, True],
1293
+ ]
1294
+ )
1295
+ def test_add_open_assistant(fixup_personality, only_personality, deberta_grading, save_json=True):
1296
+ """
1297
+ Flatten tree structure into one row per path from root to leaf
1298
+ Also turn into human_bot prompting format:
1299
+ <human>: question\n<bot>: answer <human>: question2\n<bot>: answer2 Etc.
1300
+ Also saves a .json locally as side-effect
1301
+ returns list of dicts, containing input, prompt_type and source
1302
+ """
1303
+ from datasets import load_dataset
1304
+ data_file = "OpenAssistant/oasst1"
1305
+ ds = load_dataset(data_file)
1306
+ df = pd.concat([ds['train'].to_pandas(), ds['validation'].to_pandas()], axis=0)
1307
+ rows = {}
1308
+ message_ids = df['message_id'].values.tolist()
1309
+ message_tree_ids = df['message_tree_id'].values.tolist()
1310
+ parent_ids = df['parent_id'].values.tolist()
1311
+ texts = df['text'].values.tolist()
1312
+ roles = df['role'].values.tolist()
1313
+
1314
+ for i in range(df.shape[0]):
1315
+ # collect all trees
1316
+ message_id = message_ids[i]
1317
+ message_tree_id = message_tree_ids[i]
1318
+ parent_id = parent_ids[i]
1319
+ text = texts[i]
1320
+ if fixup_personality:
1321
+ text = text.replace("Open Assistant", "h2oGPT")
1322
+ text = text.replace("Open-Assistant", "h2oGPT")
1323
+ text = text.replace("open-assistant", "h2oGPT")
1324
+ text = text.replace("OpenAssistant", "h2oGPT")
1325
+ text = text.replace("open assistant", "h2oGPT")
1326
+ text = text.replace("Open Assistand", "h2oGPT")
1327
+ text = text.replace("Open Assitant", "h2oGPT")
1328
+ text = text.replace("Open Assistent", "h2oGPT")
1329
+ text = text.replace("Open Assisstant", "h2oGPT")
1330
+ text = text.replace("Open Assitent", "h2oGPT")
1331
+ text = text.replace("Open Assitiant", "h2oGPT")
1332
+ text = text.replace("Open Assistiant", "h2oGPT")
1333
+ text = text.replace("Open Assitan ", "h2oGPT ")
1334
+ text = text.replace("Open Assistan ", "h2oGPT ")
1335
+ text = text.replace("Open Asistant", "h2oGPT")
1336
+ text = text.replace("Open Assiant", "h2oGPT")
1337
+ text = text.replace("Assistant", "h2oGPT")
1338
+ text = text.replace("LAION AI", "H2O.ai")
1339
+ text = text.replace("LAION-AI", "H2O.ai")
1340
+ text = text.replace("LAION,", "H2O.ai,")
1341
+ text = text.replace("LAION.ai", "H2O.ai")
1342
+ text = text.replace("LAION.", "H2O.ai.")
1343
+ text = text.replace("LAION", "H2O.ai")
1344
+
1345
+ role = roles[i]
1346
+ new_data = ('<human>: ' if role == 'prompter' else '<bot>: ') + text
1347
+ entry = dict(message_id=message_id, parent_id=parent_id, text=new_data)
1348
+ if message_tree_id not in rows:
1349
+ rows[message_tree_id] = [entry]
1350
+ else:
1351
+ rows[message_tree_id].append(entry)
1352
+
1353
+ all_rows = []
1354
+
1355
+ for node_id in rows:
1356
+ # order responses in tree, based on message/parent relationship
1357
+ conversations = []
1358
+
1359
+ list_msgs = rows[node_id]
1360
+ # find start
1361
+ while len(list_msgs):
1362
+ for i, leaf in enumerate(list_msgs):
1363
+ found = False
1364
+ parent_id = leaf['parent_id']
1365
+ if parent_id is None:
1366
+ # conversation starter
1367
+ conversations.append(leaf)
1368
+ found = True
1369
+ else:
1370
+ for conv in conversations:
1371
+ # find all conversations to add my message to
1372
+ if parent_id in conv['message_id'] and parent_id != conv['message_id'][-len(parent_id):]:
1373
+ # my message doesn't follow conversation
1374
+ continue
1375
+ if parent_id == conv['message_id'][-len(parent_id):]:
1376
+ # my message follows conversation, but fork first, so another follow-on message can do same
1377
+ conversations.append(conv.copy())
1378
+ conv['text'] += f"""
1379
+ {leaf['text']}
1380
+ """
1381
+ conv['message_id'] += leaf['message_id']
1382
+ found = True
1383
+ break
1384
+ if found:
1385
+ # my content was used, so nuke from list
1386
+ del list_msgs[i]
1387
+ break
1388
+
1389
+ # now reduce down to final conversations, find the longest chains of message ids
1390
+ for i, conv in enumerate(conversations):
1391
+ for j, conv2 in enumerate(conversations):
1392
+ if i == j:
1393
+ continue
1394
+ if conv['message_id'] and conv2['message_id']:
1395
+ assert conv['message_id'] != conv2['message_id']
1396
+ # delete the shorter conversation, if one contains the other
1397
+ if conv['message_id'] in conv2['message_id']:
1398
+ conv['message_id'] = None
1399
+ elif conv2['message_id'] in conv['message_id']:
1400
+ conv2['message_id'] = None
1401
+ conversations = [c for c in conversations if c['message_id']]
1402
+ if only_personality:
1403
+ all_rows.extend(
1404
+ [dict(input=c['text'] + "\n<human>:", prompt_type='plain', source=data_file) for c in conversations if
1405
+ 'h2oGPT' in c['text']])
1406
+ else:
1407
+ all_rows.extend(
1408
+ [dict(input=c['text'] + "\n<human>:", prompt_type='plain', source=data_file) for c in conversations if
1409
+ "What is H2O.ai" not in c['text']])
1410
+ unhelpful = get_unhelpful_list()
1411
+ all_rows = [x for x in all_rows if not any(u in x['input'] for u in unhelpful)]
1412
+ personality = create_personality_data()
1413
+ all_rows.extend(personality * 10)
1414
+ np.random.seed(123)
1415
+ np.random.shuffle(all_rows)
1416
+ print(len(all_rows))
1417
+ if deberta_grading:
1418
+ df = pd.DataFrame(all_rows)
1419
+ df = df.rename(columns={'input': 'text'})
1420
+ df = add_deberta_grade(df)
1421
+ df = df.rename(columns={'text': 'input'})
1422
+ drop = True
1423
+ if drop:
1424
+ min_grade = 0.3
1425
+ max_grade = np.inf
1426
+ before_rows = df.shape[0]
1427
+ df = df[df['grade_deberta'] >= min_grade]
1428
+ df = df[df['grade_deberta'] <= max_grade]
1429
+ after_rows = df.shape[0]
1430
+ print("Dropped %d rows out of %d due to deberta grade" % (before_rows - after_rows, before_rows))
1431
+ print("After DeBERTa grade")
1432
+ print(df.describe())
1433
+ all_rows = []
1434
+ for i in range(df.shape[0]):
1435
+ all_rows.append(
1436
+ dict(
1437
+ input=df['input'].iloc[i],
1438
+ source=df['source'].iloc[i],
1439
+ prompt_type=df['prompt_type'].iloc[i],
1440
+ grade_deberta=df['grade_deberta'].iloc[i],
1441
+ )
1442
+ )
1443
+ if save_json:
1444
+ data_file = data_file + \
1445
+ ("_h2ogpt" if fixup_personality else "") + \
1446
+ ("_only" if only_personality else "") + \
1447
+ ("_graded" if deberta_grading else "")
1448
+ for i in range(len(all_rows)):
1449
+ all_rows[i]['id'] = i
1450
+ with open(data_file.lower().replace("/", "_") + ".json", "w") as f:
1451
+ f.write(json.dumps(all_rows, indent=2))
1452
+ return all_rows
1453
+
1454
+
1455
+ def test_finalize_to_json():
1456
+ df = pd.read_parquet('h2oGPT.cleaned.graded.human_bot.shorter.parquet')
1457
+ df = df.rename(columns={'text': 'input'})
1458
+
1459
+ print("Number of high-quality human_bot interactions: %s" % df.shape[0], flush=True)
1460
+
1461
+ print("Adding open assistant data")
1462
+ with open("openassistant_oasst1_h2ogpt_graded.json") as f:
1463
+ open_assistant = json.loads(f.read())
1464
+ df = pd.concat([df, pd.DataFrame(open_assistant)], axis=0)
1465
+
1466
+ def final_clean(df):
1467
+ from better_profanity import profanity
1468
+ profanity.load_censor_words_from_file("data/censor_words.txt")
1469
+ df['profanity'] = parallel_apply(
1470
+ df['input'],
1471
+ lambda x: profanity.contains_profanity(x),
1472
+ n_jobs=-1,
1473
+ )
1474
+ return df[(df['profanity'] == 0)].reset_index(drop=True)
1475
+
1476
+ print("Before cleaning: Number of final high-quality human_bot interactions: %s" % df.shape[0], flush=True)
1477
+ df = final_clean(df)
1478
+ print("After cleaning: Number of final high-quality human_bot interactions: %s" % df.shape[0], flush=True)
1479
+ print(df.describe())
1480
+ print(df.shape)
1481
+ row_list = []
1482
+ for i in range(df.shape[0]):
1483
+ row_list.append(
1484
+ dict(
1485
+ input=df.loc[i, 'input'],
1486
+ source=df.loc[i, 'source'],
1487
+ prompt_type='plain',
1488
+ )
1489
+ )
1490
+ np.random.seed(1234)
1491
+ np.random.shuffle(row_list)
1492
+ unhelpful = get_unhelpful_list()
1493
+ row_list = [x for x in row_list if not any(u in x['input'] for u in unhelpful)]
1494
+ for i in range(len(row_list)):
1495
+ row_list[i]['id'] = i
1496
+ row_list[i]['input'] = row_list[i]['input'].replace(" <bot>:", "\n<bot>:")
1497
+ with open('h2ogpt-oig-oasst1-instruct-cleaned-v3.json', "w") as f:
1498
+ f.write(json.dumps(row_list, indent=2))
1499
+
1500
+
1501
+ def create_personality_data():
1502
+ questions = [
1503
+ "What's your name?",
1504
+ "What is your name?",
1505
+ "What are you?",
1506
+ "Who are you?",
1507
+ "Do you have a name?",
1508
+ "Who trained you?",
1509
+ "Who created you?",
1510
+ "Who made you?",
1511
+ ]
1512
+ answers = [
1513
+ "I'm h2oGPT, a large language model by H2O.ai.",
1514
+ "I'm h2oGPT, a large language model by H2O.ai, the visionary leader in democratizing AI.",
1515
+ "My name is h2oGPT. I'm a large language model by H2O.ai, the visionary leader in democratizing AI.",
1516
+ "My name is h2oGPT. I'm a large language model trained by H2O.ai.",
1517
+ "Hi! I'm h2oGPT, a large language model by H2O.ai.",
1518
+ "Hi! I'm h2oGPT, a large language model by H2O.ai, the visionary leader in democratizing AI.",
1519
+ ]
1520
+ help = [
1521
+ "",
1522
+ " How can I help you?",
1523
+ " How may I assist you?",
1524
+ " Nice to meet you.",
1525
+ ]
1526
+ import itertools
1527
+ rows = []
1528
+ for pair in itertools.product(questions, answers, help):
1529
+ rows.append(
1530
+ dict(input=f"<human>: {pair[0]}\n<bot>: {pair[1]}{pair[2]}\n<human>:", prompt_type='plain', source="H2O.ai")
1531
+ )
1532
+ for row in [
1533
+ "<human>: What is H2O.ai?\n<bot>: H2O.ai is a technology company that aims to democratize AI and make it accessible to a broader audience by simplifying the process of creating and deploying machine learning models.\n<human>:",
1534
+ "<human>: What is h2o.ai?\n<bot>: H2O.ai is a technology company that aims to democratize AI and make it accessible to a broader audience by simplifying the process of creating and deploying machine learning models.\n<human>:",
1535
+ "<human>: What is H2O?\n<bot>: H2O.ai is a technology company that aims to democratize AI and make it accessible to a broader audience by simplifying the process of creating and deploying machine learning models.\n<human>:",
1536
+ "<human>: Who is h2o.ai?\n<bot>: H2O.ai is a technology company that aims to democratize AI and make it accessible to a broader audience by simplifying the process of creating and deploying machine learning models.\n<human>:",
1537
+ "<human>: who is h2o.ai?\n<bot>: H2O.ai is a technology company that aims to democratize AI and make it accessible to a broader audience by simplifying the process of creating and deploying machine learning models.\n<human>:",
1538
+ "<human>: who is h2o?\n<bot>: H2O.ai is a technology company that aims to democratize AI and make it accessible to a broader audience by simplifying the process of creating and deploying machine learning models.\n<human>:",
1539
+ "<human>: What is H2O.ai?\n<bot>: H2O.ai is the visionary leader in democratizing AI.\n<human>:",
1540
+ "<human>: Who is H2O.ai?\n<bot>: H2O.ai is the visionary leader in democratizing AI.\n<human>:",
1541
+ "<human>: Who is H2O?\n<bot>: H2O.ai is the visionary leader in democratizing AI.\n<human>:",
1542
+ "<human>: Who is h2o?\n<bot>: H2O.ai is the visionary leader in democratizing AI.\n<human>:",
1543
+ "<human>: who is h2o?\n<bot>: H2O.ai is the visionary leader in democratizing AI.\n<human>:",
1544
+ ]:
1545
+ rows.append(dict(input=row, prompt_type='plain', source='H2O.ai'))
1546
+ print(len(rows))
1547
+ with open("h2ogpt-personality.json", "w") as f:
1548
+ f.write(json.dumps(rows, indent=2))
1549
+ return rows
1550
+
1551
+
1552
+ def test_check_stats_data():
1553
+ filename = 'h2ogpt-oig-oasst1-instruct-cleaned-v3.json'
1554
+ df = pd.read_json(filename)
1555
+
1556
+ # get word stats
1557
+ df['char_count'] = df['input'].apply(lambda x: len(x))
1558
+ import matplotlib.pyplot as plt
1559
+ plt.figure(figsize=(10, 10))
1560
+ plt.hist(df['char_count'], bins=100)
1561
+ chars_avg = np.mean(df['char_count'])
1562
+ chars_median = np.median(df['char_count'])
1563
+ plt.title("char_count avg: %s median: %s" % (chars_avg, chars_median))
1564
+ plt.savefig('chars_hist.png')
1565
+ plt.close()
1566
+
1567
+ # get tokenize stats for a random subsample of rows (the 10% "train" split below)
1568
+ from finetune import generate_and_tokenize_prompt
1569
+ from loaders import get_loaders, get_tokenizer
1570
+ from functools import partial
1571
+
1572
+ llama_type = False
1573
+ tokenizer_base_model = base_model = 'h2oai/h2ogpt-oasst1-512-20b'
1574
+ model_loader, tokenizer_loader = get_loaders(llama_type=llama_type, model_name=base_model, reward_type=False)
1575
+ local_files_only = False
1576
+ resume_download = True
1577
+ use_auth_token = False
1578
+ tokenizer = get_tokenizer(tokenizer_loader, tokenizer_base_model, local_files_only, resume_download, use_auth_token)
1579
+ prompt_type = 'plain' # trained with data already in human bot form
1580
+ train_on_inputs = True
1581
+ add_eos_token = False
1582
+ cutoff_len = 512 # can choose 2048
1583
+ generate_and_tokenize_prompt_fun = partial(generate_and_tokenize_prompt, prompt_type=prompt_type,
1584
+ train_on_inputs=train_on_inputs, add_eos_token=add_eos_token,
1585
+ cutoff_len=cutoff_len, tokenizer=tokenizer)
1586
+ from datasets import load_dataset
1587
+ data = load_dataset("json", data_files={"train": filename})
1588
+ val_set_size = 0.90
1589
+ train_val = data["train"].train_test_split(
1590
+ test_size=val_set_size, shuffle=True, seed=42
1591
+ )
1592
+ train_data = train_val["train"]
1593
+ train_data = train_data.shuffle().map(generate_and_tokenize_prompt_fun, num_proc=os.cpu_count())
1594
+
1595
+ df_tokens = pd.DataFrame([len(x) for x in train_data['input_ids']], columns=['token_count'])
1596
+
1597
+ plt.figure(figsize=(10, 10))
1598
+ plt.hist(df_tokens['token_count'], bins=100)
1599
+ token_avg = np.mean(df_tokens['token_count'])
1600
+ token_median = np.median(df_tokens['token_count'])
1601
+ plt.title("token_count with cutoff=%s avg: %s median: %s" % (cutoff_len, token_avg, token_median))
1602
+ plt.savefig('token_hist_%s.png' % cutoff_len)
1603
+ plt.close()
1604
+
1605
+
1606
+ def get_unhelpful_list():
1607
+ # base versions
1608
+ unhelpful = ["I'm sorry, I didn't quite understand your question, could you please rephrase it?",
1609
+ "I'm sorry, but I don't understand your question. Could you please rephrase it?",
1610
+ "I'm sorry, I don't quite understand your question",
1611
+ "I'm sorry, I don't know",
1612
+ "I'm sorry, but I don't know",
1613
+ "I don't know anything",
1614
+ "I do not know",
1615
+ "I don't know",
1616
+ "I don't know how",
1617
+ "I do not know how",
1618
+ "Can you please explain what you mean",
1619
+ "please explain what you mean",
1620
+ "please explain",
1621
+ "I'm sorry, but I don't know how to tell a story. Can you please explain what you mean by",
1622
+ "I'm sorry but I don't understand what you mean",
1623
+ "I don't understand",
1624
+ "I don't have the ability",
1625
+ "I do not have the ability",
1626
+ "I do not have",
1627
+ "I am a language model,",
1628
+ "I am a large language model,",
1629
+ "I do not understand your question. Can you please try to make it clearer?",
1630
+ "I'm sorry, but as an AI language model",
1631
+ "I apologize, but I cannot rephrase text that I cannot understand. Your post is difficult to read and follow.",
1632
+ "I apologize, but I am not h2oGPT. I am a language model developed by H2O.ai. How may I help you?",
1633
+ "Sorry, but I am not an actual Linux shell, nor am I capable of emulating one. I am an open source chat assistant and would be glad t",
1634
+ "I apologize, but I cannot perform the task you have requested.",
1635
+ "I'm sorry, I cannot perform this task as I am an AI language model and do not have access",
1636
+ "I'm sorry, I'm not sure what you're asking for here.",
1637
+ "I'm not sure what you are asking",
1638
+ "You need to provide more context",
1639
+ ]
1640
+ # reduced versions, with redundant parts, just to give context for where they came from
1641
+ unhelpful += ["sorry, I didn't quite understand your question",
1642
+ "I didn't quite understand your question",
1643
+ "I didn't understand your question",
1644
+ "I did not understand your question",
1645
+ "I did not understand the question",
1646
+ "could you please rephrase"
1647
+ "could you rephrase"
1648
+ "I do not understand your question.",
1649
+ "I do not understand the question.",
1650
+ "I do not understand that question.",
1651
+ "Can you please try to make it clearer",
1652
+ "Can you try to make it clearer",
1653
+ "sorry, but as an AI language model",
1654
+ "as an AI language model",
1655
+ "I apologize, but I cannot",
1656
+ "I cannot rephrase text",
1657
+ "I cannot understand. Your post is difficult to read and follow."
1658
+ "Your post is difficult to read and follow."
1659
+ "I apologize, but I am",
1660
+ "Sorry, but I am not ",
1661
+ "nor am I capable",
1662
+ "I am not capable of",
1663
+ "I apologize, but I cannot perform the task you have requested",
1664
+ "I cannot perform the task",
1665
+ "I cannot complete the task",
1666
+ "I'm sorry",
1667
+ "I am sorry",
1668
+ "do not have access",
1669
+ "not sure what you're asking for",
1670
+ "not sure what you are asking for",
1671
+ "not sure what is being asked",
1672
+ "I'm not sure what you are asking",
1673
+ "not sure what you are asking",
1674
+ "You need to provide more context",
1675
+ "provide more context",
1676
+ ]
1677
+ unhelpful += ["As a large language model",
1678
+ "cannot provide any information",
1679
+ "As an artificial intelligence I do not have the capability",
1680
+ "As an artificial intelligence I don't have the capability",
1681
+ "As an artificial intelligence I can't",
1682
+ "As an artificial intelligence I cannot",
1683
+ "I am sorry but I do not understand",
1684
+ "Can you please explain",
1685
+ "(sorry couldn't resist)",
1686
+ "(sorry could not resist)",
1687
+ " :)",
1688
+ " ;)",
1689
+ " :-)",
1690
+ " ;-)",
1691
+ " lol ",
1692
+ "Thanks so much!!!",
1693
+ "Thank You :)!!!",
1694
+ "Please try not to repeat",
1695
+ "I am an AI language model",
1696
+ "I'm a AI assistant that",
1697
+ "I'm an AI assistant that",
1698
+ "I am an AI assistant that",
1699
+ "etc.",
1700
+ "etc.etc.",
1701
+ "etc. etc.",
1702
+ "etc etc",
1703
+ ]
1704
+ return unhelpful
1705
+
1706
+
1707
+ def test_check_unhelpful():
1708
+ # file = '/home/jon/Downloads/openassistant_oasst1_h2ogpt_graded.json'
1709
+ file = '/home/jon/Downloads/openassistant_oasst1_h2ogpt_grades.json'
1710
+ # file = 'h2ogpt-oig-oasst1-instruct-cleaned-v2.json'
1711
+
1712
+ unhelpful = get_unhelpful_list()
1713
+ # data = json.load(open(file, 'rt'))
1714
+ df = pd.read_json(file)
1715
+
1716
+ use_reward_score_threshold = False
1717
+ use_bleu_threshold = False
1718
+ use_sentence_sim = True
1719
+
1720
+ from sacrebleu.metrics import BLEU
1721
+ bleu = BLEU()
1722
+ from nltk.translate.bleu_score import sentence_bleu
1723
+
1724
+ def get_bleu(actual, expected_list):
1725
+ # return bleu.sentence_score(actual, expected_list).score
1726
+ return sentence_bleu(expected_list, actual)
1727
+
1728
+ threshold = 0.0
1729
+ if use_reward_score_threshold:
1730
+ df = df[df['grade_deberta'] > threshold]
1731
+
1732
+ # back to as if original json load
1733
+ data = df.to_dict(orient='records')
1734
+ bads = {}
1735
+ string_all = str(data)
1736
+ for sub in unhelpful:
1737
+ bads[sub] = string_all.count(sub)
1738
+ bads = {k: v for k, v in bads.items() if v > 0}
1739
+ import pprint
1740
+ pp = pprint.PrettyPrinter(indent=4)
1741
+ pp.pprint(bads)
1742
+
1743
+ total_bads = sum(list(bads.values()))
1744
+ print('total_bads: %s' % total_bads, flush=True)
1745
+
1746
+ # check just bot
1747
+ import re
1748
+ convs = [[x.strip() for x in re.split(r'%s|%s' % (human, bot), y['input']) if x.strip()] for y in data]
1749
+ humans = [[x for i, x in enumerate(y) if i % 2 == 0] for y in convs]
1750
+ bots = [[x for i, x in enumerate(y) if i % 2 == 1] for y in convs]
1751
+
1752
+ # FIXME: apply back to json etc., just see for now
1753
+ bleu_threshold = 0.9
1754
+ if use_bleu_threshold:
1755
+ bots = [[x for x in y if get_bleu(x, unhelpful) < bleu_threshold] for y in tqdm(bots)]
1756
+
1757
+ cosine_sim_threshold = 0.8
1758
+ if use_sentence_sim:
1759
+ # pip install sentence-transformers==2.2.2
1760
+ from sentence_transformers import SentenceTransformer
1761
+ # sent_model = 'bert-base-nli-mean-tokens'
1762
+ # sent_model = 'nli-distilroberta-base-v2'
1763
+ sent_model = 'all-MiniLM-L6-v2'
1764
+ model = SentenceTransformer(sent_model)
1765
+ sentence_embeddings = model.encode(unhelpful)
1766
+ from sklearn.metrics.pairwise import cosine_similarity
1767
+ bots = [x for x in tqdm(bots) if
1768
+ np.max(cosine_similarity(model.encode(x), sentence_embeddings)) < cosine_sim_threshold]
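+ # i.e. a conversation's bot turns are all dropped if any single turn's embedding has cosine
+ # similarity >= cosine_sim_threshold to any phrase in the unhelpful list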
1769
+
1770
+ bads_bots = {}
1771
+ string_all = str(bots)
1772
+ for sub in unhelpful:
1773
+ bads_bots[sub] = string_all.count(sub)
1774
+ bads_bots = {k: v for k, v in bads_bots.items() if v > 0}
1775
+ import pprint
1776
+ pp = pprint.PrettyPrinter(indent=4)
1777
+ pp.pprint(bads_bots)
1778
+
1779
+ total_bads_bots = sum(list(bads_bots.values()))
1780
+ print('threshold: %g use_bleu_threshold: %g total_bads_bots: %s total_bots: %s total_humans: %s' % (
1781
+ threshold, use_bleu_threshold, total_bads_bots, len(bots), len(humans)), flush=True)
1782
+
1783
+ # assert len(bads) == 0, bads
1784
+ assert len(bads_bots) == 0, bads_bots
1785
+
1786
+
1787
+ def test_fortune2000_personalized():
1788
+ row_list = []
1789
+ import glob
1790
+ if not os.path.isdir("wikitext"):
1791
+ raise RuntimeError("download https://github.com/h2oai/h2ogpt/files/11423008/wikitext.zip and unzip")
1792
+ for file in glob.glob("wikitext/*.txt"):
1793
+ with open(file, "r") as f:
1794
+ blob = f.read()
1795
+ N = 512 * 4
1796
+ row_list.extend([{'input': s, 'prompt_type': 'plain', 'source': "%s" % os.path.basename(file)}
1797
+ for s in get_sentences(blob, N) if s])
1798
+ personality = create_personality_data()
1799
+ import copy
1800
+ for i in range(10):
1801
+ row_list.extend(copy.deepcopy(personality))
1802
+ np.random.seed(123)
1803
+ np.random.shuffle(row_list)
1804
+ for i in range(len(row_list)):
1805
+ row_list[i]['id'] = i
1806
+ for i in range(len(row_list)):
1807
+ assert row_list[i]['id'] == i
1808
+ with open("h2ogpt-fortune2000-personalized.json", "w") as ff:
1809
+ ff.write(json.dumps(row_list, indent=2))
enums.py DELETED
@@ -1 +0,0 @@
1
- ../../enums.py
 
 
enums.py ADDED
@@ -0,0 +1,46 @@
1
+ from enum import Enum
2
+
3
+
4
+ class PromptType(Enum):
5
+ custom = -1
6
+ plain = 0
7
+ instruct = 1
8
+ quality = 2
9
+ human_bot = 3
10
+ dai_faq = 4
11
+ summarize = 5
12
+ simple_instruct = 6
13
+ instruct_vicuna = 7
14
+ instruct_with_end = 8
15
+ human_bot_orig = 9
16
+ prompt_answer = 10
17
+ open_assistant = 11
18
+ wizard_lm = 12
19
+ wizard_mega = 13
20
+ instruct_vicuna2 = 14
21
+ instruct_vicuna3 = 15
22
+ wizard2 = 16
23
+ wizard3 = 17
24
+ instruct_simple = 18
25
+
26
+
27
+ class DocumentChoices(Enum):
28
+ All_Relevant = 0
29
+ All_Relevant_Only_Sources = 1
30
+ Only_All_Sources = 2
31
+ Just_LLM = 3
32
+
33
+
34
+ class LangChainMode(Enum):
35
+ """LangChain mode"""
36
+
37
+ DISABLED = "Disabled"
38
+ CHAT_LLM = "ChatLLM"
39
+ LLM = "LLM"
40
+ ALL = "All"
41
+ WIKI = "wiki"
42
+ WIKI_FULL = "wiki_full"
43
+ USER_DATA = "UserData"
44
+ MY_DATA = "MyData"
45
+ GITHUB_H2OGPT = "github h2oGPT"
46
+ H2O_DAI_DOCS = "DriverlessAI docs"
finetune.py DELETED
@@ -1 +0,0 @@
1
- ../../finetune.py
 
 
finetune.py ADDED
@@ -0,0 +1,676 @@
1
+ import os
2
+ import sys
3
+ from functools import partial
4
+ from typing import List, Union
5
+ import fire
6
+ import numpy as np
7
+
8
+ from loaders import get_loaders, get_tokenizer
9
+ from prompter import generate_prompt, prompt_types, PromptType
10
+ from utils import get_githash, copy_code
11
+ import torch
12
+
13
+
14
+ def log(*args, **kwargs):
15
+ if int(os.environ.get("LOCAL_RANK", 0)) == 0:
16
+ if 'flush' not in kwargs:
17
+ kwargs['flush'] = True
18
+ print(*args, **kwargs)
19
+
20
+
21
+ # supported by huggingface evaluate
22
+ supported_metrics = ['bleu', 'rouge', 'sacrebleu', 'meteor']
23
+
24
+
25
+ def train(
26
+ save_code: bool = False,
27
+ run_id: int = None,
28
+
29
+ base_model: str = 'h2oai/h2ogpt-oig-oasst1-512-6_9b',
30
+ # base_model: str = 'h2oai/h2ogpt-oasst1-512-12b',
31
+ # base_model: str = 'h2oai/h2ogpt-oasst1-512-20b',
32
+ # base_model: str = 'EleutherAI/gpt-neox-20b',
33
+ # base_model: str = 'EleutherAI/pythia-12b-deduped',
34
+ # base_model: str = 'togethercomputer/GPT-NeoXT-Chat-Base-20B',
35
+ # base_model: str = 'decapoda-research/llama-7b-hf',
36
+ # base_model: str = 'decapoda-research/llama-13b-hf',
37
+ # base_model: str = 'decapoda-research/llama-30b-hf',
38
+ # base_model: str = 'EleutherAI/gpt-j-6B',
39
+
40
+ # only needed if base_model is self-exported HF state without tokenizer
41
+ tokenizer_base_model: str = None,
42
+ # tokenizer_base_model: str = 'EleutherAI/gpt-neox-20b',
43
+
44
+ data_path: str = "h2oai/openassistant_oasst1_h2ogpt",
45
+ data_col_dict: dict = None,
46
+ # data_path: str = "./dai_docs.train.json",
47
+ prompt_type: Union[str, int] = "plain", # "plain", "instruct", "quality", "human_bot", "dai_faq"
48
+
49
+ valid_path: str = None,
50
+ # valid_path: str = "./dai_docs.valid.json",
51
+
52
+ # data_mix_in_path: str = "laion/OIG", # way too big, medium quality
53
+ data_mix_in_path: str = "0-hero/OIG-small-chip2", # high quality, 50 MB, good enough for now
54
+ data_mix_in_factor: float = 0.0, # >1: more mix-in data, <1: more of data_path data
55
+ data_mix_in_col_dict: dict = {'user': 'instruction', 'chip2': 'output'},
56
+ data_mix_in_prompt_type: str = "instruct", # just instruction->output, same as instruct
57
+
58
+ output_dir: str = None,
59
+
60
+ # LoRA checkpoint continuation
61
+ lora_weights: str = "",
62
+
63
+ # batching training hyperparams
64
+ batch_size: int = 128,
65
+ micro_batch_size: int = 4,
66
+ gradient_checkpointing=False, # unnecessary with gradient accumulation enabled
67
+ fp16=True,
68
+ train_8bit=False,
69
+ train_4bit=False,
70
+
71
+ # general training hyperparams
72
+ num_epochs: float = 1,
73
+ learning_rate: float = 3e-4,
74
+
75
+ # validation settings
76
+ val_set_size: int = None,
77
+ val_metrics: List[str] = [],
78
+ eval_steps: int = None, # to control eval steps via steps
79
+ eval_epochs: float = None, # to control eval steps via epochs
80
+
81
+ # lora hyperparams
82
+ lora_r: int = 8,
83
+ lora_alpha: int = 16,
84
+ lora_dropout: float = 0.05,
85
+ lora_target_modules: List[str] = None,
86
+ llama_type: bool = None,
87
+ llama_flash_attn: bool = False,
88
+
89
+ # llm hyperparams
90
+ train_on_inputs: bool = True, # if False, masks out inputs in loss
91
+ group_by_length: bool = False, # if True, faster, but produces an odd training loss curve
92
+ resume_from_checkpoint: str = None, # either training checkpoint or final adapter
93
+ cutoff_len: int = 512, # larger values use more memory
94
+ drop_truncations: bool = False, # if True, drop any truncated long sequences
95
+
96
+ # torch training params
97
+ ddp: bool = True, # set to False if OOM with True, for multi-GPU model parallelism
98
+ local_files_only: bool = False, # else will download new versions, normally unwanted
99
+ resume_download: bool = True,
100
+ use_auth_token: Union[str, bool] = False, # True requires having run huggingface-cli login before launching
101
+ warmup_steps: int = 100,
102
+ logging_steps: int = 1,
103
+ save_steps: int = None, # must be round multiple of eval_steps
104
+ save_total_limit: int = 3,
105
+ add_eos_token: bool = False,
106
+ ):
107
+ if llama_flash_attn:
108
+ # Need to call this before importing transformers.
109
+ from llama_flash_attn_monkey_patch import replace_llama_attn_with_flash_attn
110
+ replace_llama_attn_with_flash_attn()
111
+
112
+ # allow set token directly
113
+ use_auth_token = os.environ.get("HUGGINGFACE_API_TOKEN", use_auth_token)
114
+
115
+ prompt_type = str(prompt_type) # migration from integers
116
+ assert prompt_type in prompt_types
117
+
118
+ world_size = int(os.getenv("WORLD_SIZE", 1))
119
+ local_rank = int(os.getenv("LOCAL_RANK", 0))
120
+ rank = int(os.getenv("RANK", 0))
121
+ print(f"local_rank: {local_rank}")
122
+ print(f"global rank: {rank}")
123
+
124
+ gpus = max(world_size, torch.cuda.device_count())
125
+ run_id = run_id or 0
126
+ if not data_path:
127
+ raise ValueError("No data_path provided")
128
+ if not output_dir:
129
+ output_dir = f"{base_model.split('/')[-1]}.{data_path.replace('/', '')}.{num_epochs}_epochs.{get_githash() or 'nogit'}.{run_id}"
130
+ if os.path.exists(output_dir) and not resume_from_checkpoint:
131
+ raise FileExistsError(
132
+ f"output_dir {output_dir} based on run_id {run_id} already exists. Please pick a different run_id.")
133
+ else:
134
+ if os.path.exists(output_dir) and not resume_from_checkpoint:
135
+ raise FileExistsError(
136
+ f"output_dir {output_dir} already exists. Please pick a different output_dir, or specify a run_id instead.")
137
+ device_map = "auto"
138
+
139
+ if save_code:
140
+ copy_code(run_id)
141
+ if tokenizer_base_model is None:
142
+ tokenizer_base_model = base_model
143
+ if llama_type is None:
144
+ llama_type = "llama" in base_model.lower()
145
+ if llama_type and llama_flash_attn:
146
+ import pkg_resources
147
+ try:
148
+ pkg_resources.get_distribution('flash_attn')
149
+ can_do_flash_attn = True
150
+ except (pkg_resources.DistributionNotFound, pkg_resources.ContextualVersionConflict):
151
+ can_do_flash_attn = False
152
+
153
+ if not can_do_flash_attn:
154
+ raise RuntimeError("""Flash attention not installed.
155
+ NOTE: for current pytorch 2.0, flash attention requires installing cuda 11.7 via https://developer.nvidia.com/cuda-11-7-0-download-archive?target_os=Linux&target_arch=x86_64&Distribution=Ubuntu&target_version=20.04&target_type=runfile_local and then when running, to avoid installing driver, docs, samples, just install toolkit. Then when pip installing flash attention do:
156
+
157
+ CUDA_HOME=/usr/local/cuda-11.7 pip install flash-attn""")
158
+ assert (
159
+ base_model
160
+ ), "Please specify a --base_model, e.g. --base_model='decapoda-research/llama-7b-hf'"
161
+ gradient_accumulation_steps = batch_size // micro_batch_size
162
+ assert gradient_accumulation_steps >= world_size, "must increase batch_size for multi-GPU"
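+ # e.g. (illustrative numbers) batch_size=128, micro_batch_size=4 -> 32 accumulation steps;
+ # under DDP this is later divided by world_size, so 4 GPUs would give 8 steps per rank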
163
+
164
+ device_map = "auto"
165
+
166
+ locals_dict = locals()
167
+ locals_print = '\n'.join(['%s: %s' % (k, v) for k, v in locals_dict.items()])
168
+ log(f"Training model with params:\n{locals_print}")
169
+ log("Command: %s\nHash: %s" % (str(' '.join(sys.argv)), get_githash()))
170
+
171
+ max_memory = None
172
+ if gpus > 1:
173
+ if ddp:
174
+ log("Distributed: data parallel")
175
+ device_map = {"": int(os.environ.get("LOCAL_RANK") or 0)}
176
+ gradient_accumulation_steps = gradient_accumulation_steps // world_size
177
+ else:
178
+ free_in_GB = int(min(torch.cuda.mem_get_info()) / 1024 ** 3)
179
+ max_memory = f"{free_in_GB - 2}GB"
180
+ max_memory = {i: max_memory for i in range(gpus)}
181
+ log("world_size: %d" % world_size)
182
+ log("num_gpus: %d" % gpus)
183
+ log("max mem: %s" % max_memory)
184
+
185
+ model_loader, tokenizer_loader = get_loaders(llama_type=llama_type, model_name=base_model, reward_type=False)
186
+
187
+ model = model_loader.from_pretrained(
188
+ base_model,
189
+ load_in_8bit=train_8bit,
190
+ load_in_4bit=train_4bit,
191
+ device_map=device_map,
192
+ torch_dtype=torch.float16,
193
+ max_memory=max_memory,
194
+ local_files_only=local_files_only,
195
+ trust_remote_code=True,
196
+ resume_download=resume_download,
197
+ use_auth_token=use_auth_token,
198
+ )
199
+ if gpus > 1:
200
+ if not ddp:
201
+ log("model parallel")
202
+ model.is_parallelizable = True
203
+ model.model_parallel = True
204
+
205
+ tokenizer = get_tokenizer(tokenizer_loader, tokenizer_base_model, local_files_only, resume_download, use_auth_token)
206
+
207
+ if train_8bit or train_4bit:
208
+ from peft import (
209
+ prepare_model_for_kbit_training,
210
+ )
211
+
212
+ model = prepare_model_for_kbit_training(model)
213
+
214
+ from peft import LoraConfig, get_peft_model, set_peft_model_state_dict
215
+ try:
216
+ from peft import utils
217
+ lora_mappings = utils.TRANSFORMERS_MODELS_TO_LORA_TARGET_MODULES_MAPPING.copy()
218
+ except AttributeError:
219
+ from peft import mapping
220
+ lora_mappings = mapping.TRANSFORMERS_MODELS_TO_LORA_TARGET_MODULES_MAPPING.copy()
221
+ lora_mappings['distilgpt2'] = ["c_attn"]
222
+
223
+ if lora_weights:
224
+
225
+ from peft import PeftModel
226
+ model = PeftModel.from_pretrained(
227
+ model,
228
+ lora_weights,
229
+ torch_dtype=torch.float16,
230
+ device_map=device_map,
231
+ local_files_only=local_files_only,
232
+ resume_download=resume_download,
233
+ use_auth_token=use_auth_token,
234
+ )
235
+ elif lora_r > 0:
236
+ if lora_target_modules is None:
237
+ base_model_lower = base_model.lower()
238
+ if base_model_lower in lora_mappings:
239
+ lora_target_modules_cand = [lora_mappings[base_model_lower]]
240
+ else:
241
+ lora_target_modules_cand = [["query_key_value"], ["q_proj", "v_proj"]]
242
+ else:
243
+ lora_target_modules_cand = [lora_target_modules]
244
+
245
+ for lora_target_modules in lora_target_modules_cand:
246
+ try:
247
+ config = LoraConfig(
248
+ r=lora_r,
249
+ lora_alpha=lora_alpha,
250
+ target_modules=lora_target_modules,
251
+ lora_dropout=lora_dropout,
252
+ bias="none",
253
+ task_type="CAUSAL_LM",
254
+ )
255
+ model = get_peft_model(model, config)
256
+ break
257
+ except ValueError as e:
258
+ if "Target modules" in str(e) and "not found" in str(e):
259
+ continue
260
+ else:
261
+ raise
262
+ from peft import PeftModel
263
+ assert isinstance(model, PeftModel), "LoRA failed. Please provide --lora_target_modules explicitly."
264
+ if resume_from_checkpoint:
265
+ # Check the available weights and load them
266
+ checkpoint_name = os.path.join(
267
+ resume_from_checkpoint, "pytorch_model.bin"
268
+ ) # Full checkpoint
269
+ if not os.path.exists(checkpoint_name):
270
+ checkpoint_name = os.path.join(
271
+ resume_from_checkpoint, "adapter_model.bin"
272
+ ) # only LoRA model - LoRA config above has to fit
273
+ resume_from_checkpoint = False # So the trainer won't try loading its state
274
+ # The two files above have a different name depending on how they were saved, but are actually the same.
275
+ if os.path.exists(checkpoint_name):
276
+ log(f"Restarting from {checkpoint_name}")
277
+ adapters_weights = torch.load(checkpoint_name)
278
+ set_peft_model_state_dict(model, adapters_weights)
279
+ else:
280
+ log(f"Checkpoint {checkpoint_name} not found")
281
+
282
+ print(model)
283
+ try:
284
+ # only for PeftModel
285
+ model.print_trainable_parameters() # Be more transparent about the % of trainable params.
286
+ except:
287
+ pass
288
+
289
+ metrics = {}
290
+ for name in supported_metrics:
291
+ if name in val_metrics:
292
+ import evaluate # Causes hang for 'python generate.py' on dual 4090 if imported early, 100% reproducible
293
+ metrics[name] = evaluate.load(name)
294
+ log("Using Validation Metrics: %s" % str(list(metrics.keys())))
295
+ log("Supported Metrics: %s" % supported_metrics)
296
+
297
+ if val_set_size is None:
298
+ if len(metrics) == 0:
299
+ val_set_size = 1000
300
+ else:
301
+ val_set_size = 100
302
+ log("Auto set val_set_size %s" % val_set_size)
303
+ elif val_set_size < 1.0 and val_set_size != 0:
304
+ raise RuntimeError("Fractional validation size not supported.")
305
+
306
+ from datasets import load_dataset, concatenate_datasets
307
+ if valid_path:
308
+ data = load_dataset("json", data_files={"train": data_path, "valid": valid_path})
309
+ else:
310
+ if "json" in data_path:
311
+ data = load_dataset("json", data_files={"train": data_path})
312
+ else:
313
+ data = load_dataset(data_path)
314
+ data = data.rename_columns(data_col_dict or {})
315
+
316
+ valid_data = None
317
+ train_data_mix_in = None
318
+ valid_data_mix_in = None
319
+
320
+ if data_mix_in_path and data_mix_in_factor > 0:
321
+ # get mix-in training/validation data - to keep model "sane"
322
+ num_rows = data["train"].num_rows
323
+ log("Loading mix-in dataset: %s" % data_mix_in_path)
324
+ if "json" in data_mix_in_path:
325
+ data_mix_in = load_dataset("json", data_files={"train": data_mix_in_path})["train"]
326
+ else:
327
+ data_mix_in = load_dataset(data_mix_in_path)["train"] # can be large
328
+ data_mix_in = data_mix_in.rename_columns(data_mix_in_col_dict or {})
329
+ mix_in_rows = int(num_rows * data_mix_in_factor)
330
+
331
+ if mix_in_rows > data_mix_in.num_rows:
332
+ # duplicate rows if mix-in is smaller than required
333
+ log("Duplicating mixin to compensate for its size for training size and mixin fraction")
334
+ data_mix_in = concatenate_datasets([data_mix_in] * int(np.ceil(mix_in_rows / data_mix_in.num_rows)))
335
+
336
+ # only get as much as we need to balance
337
+ valid_size = min(data_mix_in.num_rows // 2, val_set_size or 0)
338
+ train_size = max(1, min(data_mix_in.num_rows - valid_size, mix_in_rows))
339
+ mixin_small = data_mix_in.train_test_split(
340
+ test_size=train_size + valid_size,
341
+ shuffle=True, seed=np.random.randint(10000),
342
+ )["test"]
343
+ if valid_size:
344
+ mixin_train_test = mixin_small.train_test_split(
345
+ test_size=valid_size, shuffle=False,
346
+ )
347
+ train_data_mix_in = mixin_train_test["train"]
348
+ valid_data_mix_in = mixin_train_test["test"]
349
+ else:
350
+ train_data_mix_in = mixin_small
351
+
352
+ if "prompt_type" not in train_data_mix_in.column_names:
353
+ train_data_mix_in = train_data_mix_in.add_column(
354
+ "prompt_type",
355
+ [data_mix_in_prompt_type] * train_data_mix_in.num_rows,
356
+ )
357
+ log("Added prompt type %s to mix-in training data" % data_mix_in_prompt_type)
358
+ if valid_data_mix_in and "prompt_type" not in valid_data_mix_in.column_names:
359
+ valid_data_mix_in = valid_data_mix_in.add_column(
360
+ "prompt_type",
361
+ [data_mix_in_prompt_type] * valid_data_mix_in.num_rows,
362
+ )
363
+ log("Added prompt type %s to mix-in validation data" % data_mix_in_prompt_type)
364
+ log("Created mix-in data:\nTrain %s\nValid %s" % (train_data_mix_in, valid_data_mix_in))
365
+
366
+ # get our own training/validation data - for fine-tuning
367
+ if val_set_size > 0 and not valid_path and not data_mix_in_path:
368
+ # create valid split from train
369
+ train_val = data["train"].train_test_split(
370
+ test_size=val_set_size, shuffle=True, seed=42
371
+ )
372
+ train_data = train_val["train"]
373
+ valid_data = train_val["test"]
374
+ else:
375
+ train_data = data["train"]
376
+ if valid_path:
377
+ # use given valid split, has priority over data_mix_in_path
378
+ valid_data = data["valid"]
379
+ if "prompt_type" not in train_data.column_names:
380
+ train_data = train_data.add_column(
381
+ "prompt_type",
382
+ [prompt_type] * train_data.num_rows,
383
+ )
384
+ log("Added prompt type %s to training data" % prompt_type)
385
+ if valid_data and "prompt_type" not in valid_data.column_names:
386
+ valid_data = valid_data.add_column(
387
+ "prompt_type",
388
+ [prompt_type] * valid_data.num_rows,
389
+ )
390
+ log("Added prompt type %s to validation data" % prompt_type)
391
+
392
+ assert train_data is not None
393
+
394
+ generate_and_tokenize_prompt_fun = partial(generate_and_tokenize_prompt, prompt_type=prompt_type,
395
+ train_on_inputs=train_on_inputs, add_eos_token=add_eos_token,
396
+ cutoff_len=cutoff_len, tokenizer=tokenizer)
397
+
398
+ # shuffle and tokenize data
399
+ if train_data_mix_in:
400
+ train_data = concatenate_datasets([train_data, train_data_mix_in])
401
+ log("Tokenizing %s training rows" % train_data.num_rows)
402
+ train_data = train_data.shuffle().map(generate_and_tokenize_prompt_fun,
403
+ num_proc=os.cpu_count() // torch.cuda.device_count())
404
+ if drop_truncations:
405
+ log("avoid keeping truncated cases to avoid contaminating model with truncation cases. Original size: %s" % train_data.num_rows)
406
+ prune_long_sequences_func = partial(prune_long_sequences, cutoff_len=cutoff_len)
407
+ train_data = train_data.filter(prune_long_sequences_func, num_proc=os.cpu_count() // torch.cuda.device_count())
408
+ log("avoid keeping truncated cases to avoid contaminating model with truncation cases. New size: %s" % train_data.num_rows)
409
+ train_set_size = len(train_data)
410
+
411
+ if valid_data and valid_data_mix_in:
412
+ valid_data = concatenate_datasets([valid_data, valid_data_mix_in])
413
+ elif valid_data_mix_in:
414
+ valid_data = valid_data_mix_in
415
+
416
+ if valid_data:
417
+ log("Tokenizing %s validation rows" % valid_data.num_rows)
418
+ valid_data = valid_data.shuffle().map(generate_and_tokenize_prompt_fun,
419
+ num_proc=os.cpu_count() // torch.cuda.device_count())
420
+ val_set_size = len(valid_data)
421
+ else:
422
+ val_set_size = 0
423
+ log("Final fine-tuning data:\nTrain %s\nValid %s" % (train_data, valid_data))
424
+ sample_row_dict = train_data[:1]
425
+ del sample_row_dict['input_ids']
426
+ del sample_row_dict['attention_mask']
427
+ del sample_row_dict['labels']
428
+ log("Sample input: %s" % sample_row_dict)
429
+
430
+ try:
431
+ import neptune
432
+ from transformers.integrations import NeptuneCallback
433
+
434
+ neptune_run = neptune.init_run(
435
+ source_files=[],
436
+ )
437
+ log("Connected to Neptune.")
438
+ except ImportError:
439
+ neptune_run = None
440
+ log("Please pip install neptune for tracking.")
441
+ except neptune.exceptions.NeptuneMissingApiTokenException:
442
+ neptune_run = None
443
+ os.environ["NEPTUNE_MODE"] = 'debug'
444
+ log("No neptune configured, set NEPTUNE_API_TOKEN env var.")
445
+
446
+ if neptune_run:
447
+ neptune_callback = NeptuneCallback(run=neptune_run)
448
+ callbacks = [neptune_callback]
449
+ else:
450
+ from transformers.integrations import TensorBoardCallback, is_tensorboard_available
451
+ if is_tensorboard_available():
452
+ # tensorboard --logdir=runs/
453
+ from torch.utils.tensorboard import SummaryWriter
454
+ tb_writer = SummaryWriter()
455
+ callbacks = [TensorBoardCallback(tb_writer=tb_writer)]
456
+ else:
457
+ callbacks = []
458
+
459
+ expected_steps = (train_set_size * num_epochs) // batch_size
460
+ if eval_steps is None and eval_epochs is None:
461
+ # 20 evaluations for a run
462
+ eval_steps = max(1, int(expected_steps / 20))
463
+ log("Auto set eval_steps to %s out of %s total training steps" % (eval_steps, expected_steps))
464
+ elif eval_steps is None and eval_epochs is not None:
465
+ eval_steps = max(1, int(expected_steps * eval_epochs / num_epochs))
466
+ log("Auto converted eval_epochs=%s to eval_steps %s"
467
+ " out of %s total training steps" % (eval_epochs, eval_steps, expected_steps))
468
+ if save_steps is None:
469
+ save_steps = eval_steps
470
+ log("Auto step save_steps to %s" % save_steps)
471
+ elif save_steps > eval_steps:
472
+ # save_steps must be a round multiple of eval_steps
473
+ save_steps0 = save_steps
474
+ save_steps = max(1, (save_steps // eval_steps)) * eval_steps
475
+ if save_steps0 != save_steps:
476
+ log("Auto converted save_steps from %s to %s" % (save_steps0, save_steps))
477
+
478
+ def compute_metrics(eval_preds):
479
+ # e.g. see: https://huggingface.co/docs/transformers/v4.25.1/en/tasks/translation#evaluate
480
+ inputs = eval_preds.inputs
481
+ label_ids = eval_preds.label_ids
482
+ predictions = eval_preds.predictions
483
+
484
+ # inputs = np.where(inputs != -100, inputs, tokenizer.pad_token_id)
485
+ # decoded_inputs = tokenizer.batch_decode(inputs, skip_special_tokens=True)
486
+ # decoded_inputs = [pred.strip() for pred in decoded_inputs]
487
+
488
+ label_ids = np.where(label_ids != -100, label_ids, tokenizer.pad_token_id)
489
+ # tokenizer behavior like generate time
490
+ decoded_labels = tokenizer.batch_decode(label_ids, skip_special_tokens=True,
491
+ clean_up_tokenization_spaces=True)
492
+ decoded_labels = [pred.strip() for pred in decoded_labels]
493
+
494
+ predictions = np.argmax(predictions, -1)
495
+ predictions = np.where(predictions != -100, predictions, tokenizer.pad_token_id)
496
+ # tokenizer behavior like generate time
497
+ decoded_predictions = tokenizer.batch_decode(predictions, skip_special_tokens=True,
498
+ clean_up_tokenization_spaces=True)
499
+ decoded_predictions = [pred.strip() for pred in decoded_predictions]
500
+
501
+ result = {}
502
+ for metric in metrics.values():
503
+ result1 = metric.compute(predictions=decoded_predictions, references=decoded_labels)
504
+ # get rid of lists, for precision etc., for now
505
+ numeric_results = {k: v for k, v in result1.items() if isinstance(v, (int, float))}
506
+ result.update(numeric_results)
507
+ return result
508
+
509
+ # the callback that computes metrics of interest
510
+ if val_metrics:
511
+ trainer_kwargs = dict(compute_metrics=compute_metrics)
512
+ else:
513
+ trainer_kwargs = dict()
514
+
515
+ import transformers
516
+ trainer = transformers.Trainer(
517
+ model=model,
518
+ tokenizer=tokenizer,
519
+ train_dataset=train_data,
520
+ eval_dataset=valid_data,
521
+ # FIXME: might need Seq2SeqTrainingArguments for some models
522
+ args=transformers.TrainingArguments(
523
+ per_device_train_batch_size=micro_batch_size,
524
+ per_device_eval_batch_size=1,
525
+ eval_accumulation_steps=10,
526
+ # predict_with_generate=True, # SEQ2SEQ only
527
+ include_inputs_for_metrics=True,
528
+ gradient_accumulation_steps=gradient_accumulation_steps,
529
+ warmup_steps=warmup_steps,
530
+ num_train_epochs=num_epochs,
531
+ learning_rate=learning_rate,
532
+ gradient_checkpointing=gradient_checkpointing,
533
+ fp16=fp16,
534
+ # consider 8-bit adam: https://huggingface.co/docs/transformers/v4.18.0/en/performance#8bit-adam
535
+ optim="adamw_torch", # consider "adafactor" to save memory
536
+ logging_steps=logging_steps,
537
+ logging_strategy="steps",
538
+ evaluation_strategy="steps" if val_set_size > 0 else "no",
539
+ save_strategy="steps",
540
+ eval_steps=eval_steps if val_set_size > 0 else None,
541
+ save_steps=save_steps,
542
+ output_dir=output_dir,
543
+ save_total_limit=save_total_limit,
544
+ load_best_model_at_end=True if val_set_size > 0 else False,
545
+ ddp_find_unused_parameters=False if ddp else None,
546
+ group_by_length=group_by_length,
547
+ # fsdp="shard_grad_op auto_wrap" if gpus > 1 and not ddp else None,
548
+ # fsdp_min_num_params=20000 if gpus > 1 and not ddp else None,
549
+ report_to='tensorboard' if not neptune_run else 'neptune',
550
+ ),
551
+ data_collator=transformers.DataCollatorForSeq2Seq(
552
+ tokenizer, pad_to_multiple_of=8, return_tensors="pt", padding=True
553
+ ),
554
+ callbacks=callbacks,
555
+ **trainer_kwargs,
556
+ )
557
+ model.config.use_cache = False
558
+
559
+ old_state_dict = model.state_dict
560
+ from peft import get_peft_model_state_dict
561
+
562
+ model.state_dict = (
563
+ lambda self, *_, **__: get_peft_model_state_dict(self, old_state_dict())
564
+ ).__get__(model, type(model))
565
+
566
+ if torch.__version__ >= "2" and sys.platform != "win32":
567
+ model = torch.compile(model)
568
+ # WIP (not generally replacing layers until pytorch 2.1)
569
+ if not llama_flash_attn:
570
+ torch.backends.cuda.enable_flash_sdp(True)
571
+
572
+ if gpus > 1 and not ddp:
573
+ assert trainer.is_model_parallel
574
+ else:
575
+ assert not trainer.is_model_parallel
576
+ trainer.train(resume_from_checkpoint=resume_from_checkpoint)
577
+
578
+ model.save_pretrained(output_dir)
579
+
580
+ log("\n If there's a warning about missing keys above, please disregard :)")
581
+
582
+
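The state_dict override a few lines above means save_pretrained() writes only the LoRA adapter weights rather than the full model. Below is a minimal, self-contained sketch of that idea; the base model and LoRA settings are illustrative only, not the ones this script uses.

# Illustrative sketch only: extract just the adapter weights from a PEFT-wrapped model.
from transformers import AutoModelForCausalLM
from peft import LoraConfig, get_peft_model, get_peft_model_state_dict

base = AutoModelForCausalLM.from_pretrained("distilgpt2")  # small example base model
peft_model = get_peft_model(base, LoraConfig(r=4, lora_alpha=8, target_modules=["c_attn"],
                                             task_type="CAUSAL_LM"))
adapter_only = get_peft_model_state_dict(peft_model)  # only LoRA tensors, what gets saved to adapter_model.bin
print(len(adapter_only), "adapter tensors vs", len(peft_model.state_dict()), "total tensors")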
583
+ def tokenize(prompt, tokenizer, cutoff_len, add_eos_token=False):
584
+ # there's probably a way to do this with the tokenizer settings
585
+ # but again, gotta move fast
586
+ result = tokenizer(
587
+ prompt,
588
+ truncation=True,
589
+ max_length=cutoff_len,
590
+ padding=False,
591
+ return_tensors=None,
592
+ )
593
+ if (
594
+ result["input_ids"][-1] != tokenizer.eos_token_id
595
+ and len(result["input_ids"]) < cutoff_len
596
+ and add_eos_token
597
+ ):
598
+ result["input_ids"].append(tokenizer.eos_token_id)
599
+ result["attention_mask"].append(1)
600
+
601
+ result["labels"] = result["input_ids"].copy()
602
+
603
+ return result
604
+
605
+
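For reference, a quick sketch of how tokenize() above behaves, assuming a generic Hugging Face tokenizer (distilgpt2 here is just an example): the EOS token is appended only when the sequence was not truncated and add_eos_token is set, and labels start out as a plain copy of input_ids.

from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("distilgpt2")  # example tokenizer, not the one used for training
out = tokenize("Hello world", tok, cutoff_len=16, add_eos_token=True)
# EOS appended because len(input_ids) < cutoff_len and add_eos_token=True
assert out["input_ids"][-1] == tok.eos_token_id
# labels are a copy of input_ids, i.e. standard causal-LM next-token targets
assert out["labels"] == out["input_ids"]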
606
+ def prune_long_sequences(data_point, cutoff_len=None):
607
+ """
608
+ Prune if too long for tokenizer, so truncation doesn't lead training to learn from truncated language
609
+ :param data_point:
610
+ :param cutoff_len:
611
+ :return:
612
+ """
613
+ assert cutoff_len is not None
614
+ return len(data_point['input_ids']) < cutoff_len
615
+
616
+
617
+ def generate_and_tokenize_prompt(data_point, prompt_type=None, train_on_inputs=False, add_eos_token=False,
618
+ cutoff_len=None, tokenizer=None):
619
+ assert prompt_type is not None
620
+ assert cutoff_len is not None
621
+ assert tokenizer is not None
622
+ prompt_dict = '' # only for custom prompt_type
623
+ assert prompt_type != PromptType.custom.name, "custom not setup for finetune"
624
+ full_prompt, _, _, _ = generate_prompt(data_point, prompt_type, prompt_dict, False, False)
625
+ tokenized_full_prompt = tokenize(full_prompt, tokenizer, cutoff_len, add_eos_token=add_eos_token)
626
+ if not train_on_inputs:
627
+ user_prompt, _, _, _ = generate_prompt({**data_point, "output": ""}, prompt_type, prompt_dict, False, False)
628
+ tokenized_user_prompt = tokenize(user_prompt, tokenizer, cutoff_len, add_eos_token=add_eos_token)
629
+ user_prompt_len = len(tokenized_user_prompt["input_ids"])
630
+ if add_eos_token:
631
+ user_prompt_len -= 1
632
+
633
+ # ignore_index=-100 ensures torch/tf don't include padding token id in CrossEntropyLoss
634
+ tokenized_full_prompt["labels"] = [
635
+ -100
636
+ ] * user_prompt_len + tokenized_full_prompt["labels"][
637
+ user_prompt_len:
638
+ ] # could be sped up, probably
639
+ return tokenized_full_prompt
640
+
641
+
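A tiny worked example of the label masking above, using made-up token ids: when the user prompt occupies the first user_prompt_len positions, those labels become -100 so the loss is computed only on the response tokens.

# Hypothetical numbers, just to show the masking shape used above.
user_prompt_len = 5
full_labels = [11, 22, 33, 44, 55, 66, 77, 88, 99]  # token ids of the full prompt (prompt + response)
masked = [-100] * user_prompt_len + full_labels[user_prompt_len:]
assert masked == [-100, -100, -100, -100, -100, 66, 77, 88, 99]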
642
+ def test_debug():
643
+ fire.Fire(train)
644
+
645
+
646
+ if __name__ == "__main__":
647
+ CONFIG = "NCCL_P2P_LEVEL=LOC WORLD_SIZE=5 torchrun --nnodes=5 --master_addr=10.10.10.2 --master_port=1111 --nproc_per_node=1"
648
+ CMD = "finetune.py --data_path=config.json --num_epochs=1 --base_model=decapoda-research/llama-13b-hf"
649
+ log(f"""
650
+ Example runs on 4 GPUs:
651
+ WORLD_SIZE=4 CUDA_VISIBLE_DEVICES="0,1,2,3" torchrun --nproc_per_node=4 finetune.py --base_model='decapoda-research/llama-7b-hf' --data_path=data/config.json --run_id=0 &> 0.log
652
+ WORLD_SIZE=4 CUDA_VISIBLE_DEVICES="0,1,2,3" torchrun --nproc_per_node=4 finetune.py --base_model='decapoda-research/llama-30b-hf' --data_path=data/config.json --batch_size=16 --micro_batch_size=1 --run_id=1 --save_code=True &> 1.log
653
+ WORLD_SIZE=4 CUDA_VISIBLE_DEVICES="0,1,2,3" torchrun --nproc_per_node=4 finetune.py --base_model='EleutherAI/gpt-j-6B' --data_path=data/config.json --run_id=2 &> 2.log
654
+ WORLD_SIZE=4 CUDA_VISIBLE_DEVICES="0,1,2,3" torchrun --nproc_per_node=4 finetune.py --base_model='EleutherAI/gpt-neox-20b' --data_path=data/config.json --run_id=8 --batch_size=16 --micro_batch_size=4 &> 8.log
655
+ WORLD_SIZE=4 CUDA_VISIBLE_DEVICES="0,1,2,3" torchrun --nproc_per_node=4 finetune.py --base_model='togethercomputer/GPT-NeoXT-Chat-Base-20B' --data_path=data/config.json --prompt_type='dai_faq' --run_id=13 --batch_size=16 --micro_batch_size=4 --num_epochs=100 --val_set_size=0 data_mix_in_path='' &> 13.log
656
+ WORLD_SIZE=4 CUDA_VISIBLE_DEVICES="0,1,2,3" torchrun --nproc_per_node=4 finetune.py --base_model='togethercomputer/GPT-NeoXT-Chat-Base-20B' --data_path=data/config.json --run_id=28 --batch_size=16 --micro_batch_size=4 --num_epochs=8 --val_set_size=0 --data_mix_in_factor=0.1 --data_mix_in_prompt_type='human_bot' --save_code=True --cutoff_len=512 &> 28.log
657
+
658
+ All metrics:
659
+ CUDA_VISIBLE_DEVICES= finetune.py --data_mix_in_factor=0 --eval_steps=100 --warmup_steps=2 --val_set_size=100 --val_metrics="['bleu', 'rouge', 'sacrebleu', 'meteor']"
660
+
661
+ # Fine-tune 20B on 24GB GPUs across 3 nodes with 3+2+2 GPUs
662
+ rippa>
663
+ NCCL_P2P_LEVEL=LOC WORLD_SIZE=7 CUDA_VISIBLE_DEVICES="0,1,2" torchrun --node_rank 0 --nproc_per_node=3 --master_port=1234 --nnodes=3 --master_addr=10.10.10.2 finetune.py --data_path=merged_shuffled_OIG_87f6a1e788.json --micro_batch_size=1 --batch_size=7 --cutoff_len=512 --run_id=17 &>log.17.rank0
664
+ ova>
665
+ NCCL_P2P_LEVEL=LOC WORLD_SIZE=7 CUDA_VISIBLE_DEVICES="0,1" torchrun --node_rank 1 --nproc_per_node=2 --master_port=1234 --nnodes=3 --master_addr=10.10.10.2 finetune.py --data_path=merged_shuffled_OIG_87f6a1e788.json --micro_batch_size=1 --batch_size=7 --cutoff_len=512 --run_id=17 &>log.17.rank1
666
+ timemachine>
667
+ NCCL_P2P_LEVEL=LOC WORLD_SIZE=7 CUDA_VISIBLE_DEVICES="0,1" torchrun --node_rank 2 --nproc_per_node=2 --master_port=1234 --nnodes=3 --master_addr=10.10.10.2 finetune.py --data_path=merged_shuffled_OIG_87f6a1e788.json --micro_batch_size=1 --batch_size=7 --cutoff_len=512 --run_id=17 &>log.17.rank2
668
+
669
+ """, flush=True)
670
+
671
+ if os.environ.get("LOCAL_RANK") is None:
672
+ # then not using torchrun, so can't do distributed, ensure CVD set
673
+ assert os.environ.get(
674
+ "CUDA_VISIBLE_DEVICES") is not None, "Run python script using: torchrun finetune.py OR set CUDA_VISIBLE_DEVICES to single GPU"
675
+
676
+ fire.Fire(train)
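For orientation, a sketch (with made-up sizes) of the step bookkeeping train() performs above: expected_steps is the rough number of optimizer steps, about 20 evaluations are spread across the run when eval_steps is not given, and save_steps is rounded down to a multiple of eval_steps.

# Made-up sizes, mirroring the auto-selection logic in train() above.
train_set_size, num_epochs, batch_size = 10000, 3, 128
expected_steps = (train_set_size * num_epochs) // batch_size   # 234 optimizer steps
eval_steps = max(1, int(expected_steps / 20))                  # ~20 evaluations per run -> 11
save_steps = 30
save_steps = max(1, (save_steps // eval_steps)) * eval_steps   # rounded to a multiple of eval_steps -> 22
print(expected_steps, eval_steps, save_steps)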
generate.py DELETED
@@ -1 +0,0 @@
1
- ../../generate.py
 
 
generate.py ADDED
@@ -0,0 +1,1712 @@
1
+ import ast
2
+ import functools
3
+ import glob
4
+ import inspect
5
+ import queue
6
+ import shutil
7
+ import sys
8
+ import os
9
+ import time
10
+ import traceback
11
+ import typing
12
+ import warnings
13
+ from datetime import datetime
14
+ import filelock
15
+ import psutil
16
+
17
+ os.environ['HF_HUB_DISABLE_TELEMETRY'] = '1'
18
+ os.environ['BITSANDBYTES_NOWELCOME'] = '1'
19
+ warnings.filterwarnings('ignore', category=UserWarning, message='TypedStorage is deprecated')
20
+
21
+ from enums import DocumentChoices, LangChainMode
22
+ from loaders import get_loaders
23
+ from utils import set_seed, clear_torch_cache, save_generate_output, NullContext, wrapped_partial, EThread, get_githash, \
24
+ import_matplotlib, get_device, makedirs, get_kwargs, start_faulthandler
25
+
26
+ start_faulthandler()
27
+ import_matplotlib()
28
+
29
+ SEED = 1236
30
+ set_seed(SEED)
31
+
32
+ from typing import Union
33
+
34
+ import fire
35
+ import torch
36
+ from peft import PeftModel
37
+ from transformers import GenerationConfig, AutoModel, TextIteratorStreamer
38
+ from accelerate import init_empty_weights, infer_auto_device_map
39
+
40
+ from prompter import Prompter, inv_prompt_type_to_model_lower, non_hf_types, PromptType, get_prompt
41
+ from stopping import get_stopping
42
+
43
+ eval_extra_columns = ['prompt', 'response', 'score']
44
+
45
+ langchain_modes = [x.value for x in list(LangChainMode)]
46
+
47
+ scratch_base_dir = '/tmp/'
48
+
49
+
50
+ def main(
51
+ load_8bit: bool = False,
52
+ load_4bit: bool = False,
53
+ load_half: bool = True,
54
+ infer_devices: bool = True,
55
+ base_model: str = '',
56
+ tokenizer_base_model: str = '',
57
+ lora_weights: str = "",
58
+ gpu_id: int = 0,
59
+ compile_model: bool = True,
60
+
61
+ prompt_type: Union[int, str] = None,
62
+ prompt_dict: typing.Dict = None,
63
+ # input to generation
64
+ temperature: float = None,
65
+ top_p: float = None,
66
+ top_k: int = None,
67
+ num_beams: int = None,
68
+ repetition_penalty: float = None,
69
+ num_return_sequences: int = None,
70
+ do_sample: bool = None,
71
+ max_new_tokens: int = None,
72
+ min_new_tokens: int = None,
73
+ early_stopping: Union[bool, str] = None,
74
+ max_time: float = None,
75
+
76
+ memory_restriction_level: int = None,
77
+ debug: bool = False,
78
+ save_dir: str = None,
79
+ share: bool = True,
80
+ local_files_only: bool = False,
81
+ resume_download: bool = True,
82
+ use_auth_token: Union[str, bool] = False,
83
+ trust_remote_code: Union[str, bool] = True,
84
+ offload_folder: str = "offline_folder",
85
+
86
+ src_lang: str = "English",
87
+ tgt_lang: str = "Russian",
88
+
89
+ cli: bool = False,
90
+ cli_loop: bool = True,
91
+ gradio: bool = True,
92
+ gradio_avoid_processing_markdown: bool = False,
93
+ gradio_offline_level: int = 0,
94
+ chat: bool = True,
95
+ chat_context: bool = False,
96
+ stream_output: bool = True,
97
+ show_examples: bool = None,
98
+ verbose: bool = False,
99
+ h2ocolors: bool = False,
100
+ height: int = 600,
101
+ show_lora: bool = True,
102
+ login_mode_if_model0: bool = False,
103
+ block_gradio_exit: bool = True,
104
+ concurrency_count: int = 1,
105
+ api_open: bool = False,
106
+ allow_api: bool = True,
107
+ input_lines: int = 1,
108
+ auth: typing.List[typing.Tuple[str, str]] = None,
109
+
110
+ sanitize_user_prompt: bool = True,
111
+ sanitize_bot_response: bool = True,
112
+
113
+ extra_model_options: typing.List[str] = [],
114
+ extra_lora_options: typing.List[str] = [],
115
+
116
+ score_model: str = 'OpenAssistant/reward-model-deberta-v3-large-v2',
117
+ auto_score: bool = True,
118
+
119
+ eval_filename: str = None,
120
+ eval_prompts_only_num: int = 0,
121
+ eval_prompts_only_seed: int = 1234,
122
+ eval_as_output: bool = False,
123
+
124
+ langchain_mode: str = 'Disabled',
125
+ visible_langchain_modes: list = ['UserData', 'MyData'],
126
+ document_choice: list = [DocumentChoices.All_Relevant.name],
127
+ user_path: str = None,
128
+ detect_user_path_changes_every_query: bool = False,
129
+ load_db_if_exists: bool = True,
130
+ keep_sources_in_context: bool = False,
131
+ db_type: str = 'chroma',
132
+ use_openai_embedding: bool = False,
133
+ use_openai_model: bool = False,
134
+ hf_embedding_model: str = None,
135
+ allow_upload_to_user_data: bool = True,
136
+ allow_upload_to_my_data: bool = True,
137
+ enable_url_upload: bool = True,
138
+ enable_text_upload: bool = True,
139
+ enable_sources_list: bool = True,
140
+ chunk: bool = True,
141
+ chunk_size: int = 512,
142
+ top_k_docs: int = 3, # FIXME: Can go back to 4 once https://github.com/h2oai/h2ogpt/issues/192 fixed
143
+ n_jobs: int = -1,
144
+ enable_captions: bool = True,
145
+ captions_model: str = "Salesforce/blip-image-captioning-base",
146
+ pre_load_caption_model: bool = False,
147
+ caption_gpu: bool = True,
148
+ enable_ocr: bool = False,
149
+ ):
150
+ """
151
+
152
+ :param load_8bit: load model in 8-bit using bitsandbytes
153
+ :param load_4bit: load model in 4-bit using bitsandbytes
154
+ :param load_half: load model in float16
155
+ :param infer_devices: whether to control devices with gpu_id. If False, then spread across GPUs
156
+ :param base_model: model HF-type name. If --base_model is used to preload a model, it cannot be unloaded in the gradio Models tab
157
+ :param tokenizer_base_model: tokenizer HF-type name. Usually not required, inferred from base_model.
158
+ :param lora_weights: LORA weights path/HF link
159
+ :param gpu_id: if infer_devices, then use gpu_id for cuda device ID, or auto mode if gpu_id == -1
160
+ :param compile_model: whether to compile the model
161
+ :param prompt_type: type of prompt, usually matched to fine-tuned model or plain for foundational model
162
+ :param prompt_dict: If prompt_type=custom, then expects (some) items returned by get_prompt(..., return_dict=True)
163
+ :param temperature: generation temperature
164
+ :param top_p: generation top_p
165
+ :param top_k: generation top_k
166
+ :param num_beams: generation number of beams
167
+ :param repetition_penalty: generation repetition penalty
168
+ :param num_return_sequences: generation number of sequences (1 forced for chat)
169
+ :param do_sample: generation sample
170
+ :param max_new_tokens: generation max new tokens
171
+ :param min_new_tokens: generation min tokens
172
+ :param early_stopping: generation early stopping
173
+ :param max_time: maximum time to allow for generation
174
+ :param memory_restriction_level: 0 = no restriction on tokens or model, 1 = some restrictions on tokens, 2 = HF-like restriction, 3 = very low memory case
175
+ :param debug: enable debug mode
176
+ :param save_dir: directory chat data is saved to
177
+ :param share: whether to share the gradio app with sharable URL
178
+ :param local_files_only: whether to only use local files instead of going to HF for models
179
+ :param resume_download: whether to resume downloads from HF for models
180
+ :param use_auth_token: whether to use HF auth token (requires having run huggingface-cli login beforehand)
181
+ :param trust_remote_code: whether to trust any remote code needed for the HF model
182
+ :param offload_folder: path for spilling model onto disk
183
+ :param src_lang: source languages to include if doing translation (None = all)
184
+ :param tgt_lang: target languages to include if doing translation (None = all)
185
+ :param cli: whether to use CLI (non-gradio) interface.
186
+ :param cli_loop: whether to loop for CLI (False usually only for testing)
187
+ :param gradio: whether to enable gradio, or to enable benchmark mode
188
+ :param gradio_avoid_processing_markdown:
189
+ :param gradio_offline_level: if > 0, change fonts so the UI is fully offline
190
+ == 1 means backend won't need internet for fonts, but front-end UI might if font not cached
191
+ == 2 means backend and frontend don't need internet to download any fonts.
192
+ Note: Some things always disabled include HF telemetry, gradio telemetry, chromadb posthog that involve uploading.
193
+ This option additionally disables downloading Google fonts, which is less intrusive than uploading,
194
+ but disabling it is still required in the air-gapped case. The substitute fonts don't look as nice as Google fonts, but they ensure fully offline behavior.
195
+ :param chat: whether to enable chat mode with chat history
196
+ :param chat_context: whether to use extra helpful context if human_bot
197
+ :param stream_output: whether to stream output from generate
198
+ :param show_examples: whether to show clickable examples in gradio
199
+ :param verbose: whether to show verbose prints
200
+ :param h2ocolors: whether to use H2O.ai theme
201
+ :param height: height of chat window
202
+ :param show_lora: whether to show LORA options in UI (expert so can be hard to understand)
203
+ :param login_mode_if_model0: set to True to load --base_model after client logs in, to be able to free GPU memory when model is swapped
204
+ :param block_gradio_exit: whether to block gradio exit (used for testing)
205
+ :param concurrency_count: gradio concurrency count (1 is optimal for LLMs)
206
+ :param api_open: If False, don't let API calls skip gradio queue
207
+ :param allow_api: whether to allow API calls at all to gradio server
208
+ :param input_lines: how many input lines to show for chat box (>1 forces shift-enter for submit, else enter is submit)
209
+ :param auth: gradio auth for launcher in form [(user1, pass1), (user2, pass2), ...]
210
+ e.g. --auth=[('jon','password')] with no spaces
211
+ :param sanitize_user_prompt: whether to remove profanity from user input
212
+ :param sanitize_bot_response: whether to remove profanity and repeat lines from bot output
213
+ :param extra_model_options: extra models to show in list in gradio
214
+ :param extra_lora_options: extra LORA to show in list in gradio
215
+ :param score_model: which model to score responses (None means no scoring)
216
+ :param auto_score: whether to automatically score responses
217
+ :param eval_filename: json file to use for evaluation; if None, uses ShareGPT
218
+ :param eval_prompts_only_num: for the non-gradio benchmark, number of eval_filename prompts to use for eval instead of examples
219
+ :param eval_prompts_only_seed: for the non-gradio benchmark, seed for eval_filename sampling
220
+ :param eval_as_output: for the non-gradio benchmark, whether to test the eval_filename output itself
221
+ :param langchain_mode: Data source to include. Choose "UserData" to only consume files from make_db.py.
222
+ WARNING: wiki_full requires extra data processing via read_wiki_full.py and a very capable workstation to generate the db, unless already present.
223
+ :param user_path: user path to glob from to generate db for vector search, for 'UserData' langchain mode.
224
+ If a db already exists, any new/changed files are added automatically when the path is set; it does not have to be the same path used for prior db sources
225
+ :param detect_user_path_changes_every_query: whether to detect if any files changed or added every similarity search (by file hashes).
226
+ Expensive for large number of files, so not done by default. By default only detect changes during db loading.
227
+ :param visible_langchain_modes: dbs to generate at launch to be ready for LLM
228
+ Can be up to ['wiki', 'wiki_full', 'UserData', 'MyData', 'github h2oGPT', 'DriverlessAI docs']
229
+ But wiki_full is expensive and requires preparation
230
+ To allow scratch space that only lives in the session, add 'MyData' to the list
231
+ Default: If only want to consume local files, e.g. prepared by make_db.py, only include ['UserData']
232
+ FIXME: Avoid 'All' for now, not implemented
233
+ :param document_choice: Default document choice when taking subset of collection
234
+ :param load_db_if_exists: Whether to load chroma db if exists or re-generate db
235
+ :param keep_sources_in_context: Whether to keep url sources in context, not helpful usually
236
+ :param db_type: 'faiss' for in-memory or 'chroma' or 'weaviate' for persisted on disk
237
+ :param use_openai_embedding: Whether to use OpenAI embeddings for vector db
238
+ :param use_openai_model: Whether to use OpenAI model for use with vector db
239
+ :param hf_embedding_model: Which HF embedding model to use for vector db
240
+ Default is instructor-large with 768-dimensional embeddings if GPUs are available, else all-MiniLM-L6-v2 if no GPUs
241
+ Can also choose a simpler model with 384-dimensional embeddings: "sentence-transformers/all-MiniLM-L6-v2"
242
+ Can also choose even better embedding with 1024 parameters: 'hkunlp/instructor-xl'
243
+ We support automatically changing embeddings for chroma, with a backup of the db made if this is done
244
+ :param allow_upload_to_user_data: Whether to allow file uploads to update shared vector db
245
+ :param allow_upload_to_my_data: Whether to allow file uploads to update scratch vector db
246
+ :param enable_url_upload: Whether to allow upload from URL
247
+ :param enable_text_upload: Whether to allow upload of text
248
+ :param enable_sources_list: Whether to allow listing (or downloading, for a non-shared db) the sources for the chosen db
249
+ :param chunk: Whether to chunk data (True unless you know the data is already optimally chunked)
250
+ :param chunk_size: Size of chunks, with typically the top-4 passed to the LLM, so it needs to fit in the context length
251
+ :param top_k_docs: number of chunks to give LLM
252
+ :param n_jobs: Number of processors to use when consuming documents (-1 = all, is default)
253
+ :param enable_captions: Whether to support captions using BLIP for image files as documents, then preloads that model
254
+ :param captions_model: Which model to use for captions.
255
+ captions_model: int = "Salesforce/blip-image-captioning-base", # continue capable
256
+ captions_model: str = "Salesforce/blip2-flan-t5-xl", # question/answer capable, 16GB state
257
+ captions_model: int = "Salesforce/blip2-flan-t5-xxl", # question/answer capable, 60GB state
258
+ Note: opt-based blip2 are not permissive license due to opt and Meta license restrictions
259
+ :param pre_load_caption_model: Whether to preload caption model, or load after forking parallel doc loader
260
+ parallel loading disabled if preload and have images, to prevent deadlocking on cuda context
261
+ Recommended if using larger caption model
262
+ :param caption_gpu: If captions are enabled, use the GPU if one exists
263
+ :param enable_ocr: Whether to support OCR on images
264
+ :return:
265
+ """
266
+ is_hf = bool(int(os.getenv("HUGGINGFACE_SPACES", '0')))
267
+ is_gpth2oai = bool(int(os.getenv("GPT_H2O_AI", '0')))
268
+ is_public = is_hf or is_gpth2oai # multi-user case with fixed model and disclaimer
269
+ if memory_restriction_level is None:
270
+ memory_restriction_level = 2 if is_hf else 0 # 2 assumes run on 24GB consumer GPU
271
+ else:
272
+ assert 0 <= memory_restriction_level <= 3, "Bad memory_restriction_level=%s" % memory_restriction_level
273
+ admin_pass = os.getenv("ADMIN_PASS")
274
+ # will sometimes appear in UI or sometimes actual generation, but maybe better than empty result
275
+ # but becomes unrecoverable sometimes if raise, so just be silent for now
276
+ raise_generate_gpu_exceptions = True
277
+
278
+ # allow set token directly
279
+ use_auth_token = os.environ.get("HUGGINGFACE_API_TOKEN", use_auth_token)
280
+ allow_upload_to_user_data = bool(int(os.environ.get("allow_upload_to_user_data", allow_upload_to_user_data)))
281
+ allow_upload_to_my_data = bool(int(os.environ.get("allow_upload_to_my_data", allow_upload_to_my_data)))
282
+ height = int(os.environ.get("HEIGHT", height))
283
+ h2ocolors = bool(int(os.getenv('h2ocolors', h2ocolors)))
284
+
285
+ # allow enabling langchain via ENV
286
+ # FIRST PLACE where LangChain referenced, but no imports related to it
287
+ langchain_mode = os.environ.get("LANGCHAIN_MODE", langchain_mode)
288
+ assert langchain_mode in langchain_modes, "Invalid langchain_mode %s" % langchain_mode
289
+ visible_langchain_modes = ast.literal_eval(os.environ.get("visible_langchain_modes", str(visible_langchain_modes)))
290
+ if langchain_mode not in visible_langchain_modes and langchain_mode in langchain_modes:
291
+ visible_langchain_modes += [langchain_mode]
292
+
293
+ if is_public:
294
+ allow_upload_to_user_data = False
295
+ input_lines = 1 # ensure set, for ease of use
296
+ temperature = 0.2 if temperature is None else temperature
297
+ top_p = 0.85 if top_p is None else top_p
298
+ top_k = 70 if top_k is None else top_k
299
+ if is_hf:
300
+ do_sample = True if do_sample is None else do_sample
301
+ else:
302
+ # by default don't sample, too chatty
303
+ do_sample = False if do_sample is None else do_sample
304
+
305
+ if memory_restriction_level == 2:
306
+ if not base_model:
307
+ base_model = 'h2oai/h2ogpt-oasst1-512-12b'
308
+ # don't set load_8bit if passed base_model, doesn't always work so can't just override
309
+ load_8bit = True
310
+ load_4bit = False # FIXME - consider using 4-bit instead of 8-bit
311
+ else:
312
+ base_model = 'h2oai/h2ogpt-oasst1-512-20b' if not base_model else base_model
313
+ if memory_restriction_level >= 2:
314
+ load_8bit = True
315
+ load_4bit = False # FIXME - consider using 4-bit instead of 8-bit
316
+ if hf_embedding_model is None:
317
+ hf_embedding_model = "sentence-transformers/all-MiniLM-L6-v2"
318
+ if is_hf:
319
+ # must override share if in spaces
320
+ share = False
321
+ save_dir = os.getenv('SAVE_DIR', save_dir)
322
+ score_model = os.getenv('SCORE_MODEL', score_model)
323
+ if score_model == 'None' or score_model is None:
324
+ score_model = ''
325
+ concurrency_count = int(os.getenv('CONCURRENCY_COUNT', concurrency_count))
326
+ api_open = bool(int(os.getenv('API_OPEN', api_open)))
327
+ allow_api = bool(int(os.getenv('ALLOW_API', allow_api)))
328
+
329
+ n_gpus = torch.cuda.device_count() if torch.cuda.is_available() else 0
330
+ if n_gpus == 0:
331
+ gpu_id = None
332
+ load_8bit = False
333
+ load_4bit = False
334
+ load_half = False
335
+ infer_devices = False
336
+ torch.backends.cudnn.benchmark = True
337
+ torch.backends.cudnn.enabled = False
338
+ torch.set_default_dtype(torch.float32)
339
+ if psutil.virtual_memory().available < 94 * 1024 ** 3:
340
+ # 12B uses ~94GB
341
+ # 6.9B uses ~47GB
342
+ base_model = 'h2oai/h2ogpt-oig-oasst1-512-6_9b' if not base_model else base_model
343
+ if hf_embedding_model is None:
344
+ # if no GPUs, use simpler embedding model to avoid cost in time
345
+ hf_embedding_model = "sentence-transformers/all-MiniLM-L6-v2"
346
+ else:
347
+ if hf_embedding_model is None:
348
+ # if still None, then set default
349
+ hf_embedding_model = 'hkunlp/instructor-large'
350
+
351
+ # get defaults
352
+ model_lower = base_model.lower()
353
+ if not gradio:
354
+ # force, else not single response like want to look at
355
+ stream_output = False
356
+ # else prompt removal can mess up output
357
+ chat = False
358
+ # hard-coded defaults
359
+ first_para = False
360
+ text_limit = None
361
+
362
+ if offload_folder:
363
+ makedirs(offload_folder)
364
+
365
+ user_set_max_new_tokens = max_new_tokens is not None
366
+
367
+ placeholder_instruction, placeholder_input, \
368
+ stream_output, show_examples, \
369
+ prompt_type, prompt_dict, \
370
+ temperature, top_p, top_k, num_beams, \
371
+ max_new_tokens, min_new_tokens, early_stopping, max_time, \
372
+ repetition_penalty, num_return_sequences, \
373
+ do_sample, \
374
+ src_lang, tgt_lang, \
375
+ examples, \
376
+ task_info = \
377
+ get_generate_params(model_lower, chat,
378
+ stream_output, show_examples,
379
+ prompt_type, prompt_dict,
380
+ temperature, top_p, top_k, num_beams,
381
+ max_new_tokens, min_new_tokens, early_stopping, max_time,
382
+ repetition_penalty, num_return_sequences,
383
+ do_sample,
384
+ top_k_docs,
385
+ chunk,
386
+ chunk_size,
387
+ verbose,
388
+ )
389
+
390
+ locals_dict = locals()
391
+ locals_print = '\n'.join(['%s: %s' % (k, v) for k, v in locals_dict.items()])
392
+ if verbose:
393
+ print(f"Generating model with params:\n{locals_print}", flush=True)
394
+ print("Command: %s\nHash: %s" % (str(' '.join(sys.argv)), get_githash()), flush=True)
395
+
396
+ if langchain_mode != "Disabled":
397
+ # SECOND PLACE where LangChain referenced, but all imports are kept local so not required
398
+ from gpt_langchain import prep_langchain, get_some_dbs_from_hf
399
+ if is_hf:
400
+ get_some_dbs_from_hf()
401
+ dbs = {}
402
+ for langchain_mode1 in visible_langchain_modes:
403
+ if langchain_mode1 in ['MyData']:
404
+ # don't use what is on disk, remove it instead
405
+ for gpath1 in glob.glob(os.path.join(scratch_base_dir, 'db_dir_%s*' % langchain_mode1)):
406
+ if os.path.isdir(gpath1):
407
+ print("Removing old MyData: %s" % gpath1, flush=True)
408
+ shutil.rmtree(gpath1)
409
+ continue
410
+ if langchain_mode1 in ['All']:
411
+ # FIXME: All should be avoided until scans over each db, shouldn't be separate db
412
+ continue
413
+ persist_directory1 = 'db_dir_%s' % langchain_mode1 # single place, no special names for each case
414
+ try:
415
+ db = prep_langchain(persist_directory1,
416
+ load_db_if_exists,
417
+ db_type, use_openai_embedding,
418
+ langchain_mode1, user_path,
419
+ hf_embedding_model,
420
+ kwargs_make_db=locals())
421
+ finally:
422
+ # in case updated embeddings or created new embeddings
423
+ clear_torch_cache()
424
+ dbs[langchain_mode1] = db
425
+ # remove None dbs so we can rely on "k in dbs" to check whether a db exists
426
+ dbs = {k: v for k, v in dbs.items() if v is not None}
427
+ else:
428
+ dbs = {}
429
+ # import control
430
+ if os.environ.get("TEST_LANGCHAIN_IMPORT"):
431
+ assert 'gpt_langchain' not in sys.modules, "Dev bug, import of langchain when should not have"
432
+ assert 'langchain' not in sys.modules, "Dev bug, import of langchain when should not have"
433
+
434
+ if cli:
435
+ from cli import run_cli
436
+ return run_cli(**get_kwargs(run_cli, exclude_names=['model_state0'], **locals()))
437
+ elif not gradio:
438
+ from eval import run_eval
439
+ return run_eval(**get_kwargs(run_eval, exclude_names=['model_state0'], **locals()))
440
+ elif gradio:
441
+ # imported here so don't require gradio to run generate
442
+ from gradio_runner import go_gradio
443
+
444
+ # get default model
445
+ all_kwargs = locals().copy()
446
+ if all_kwargs.get('base_model') and not all_kwargs['login_mode_if_model0']:
447
+ model0, tokenizer0, device = get_model(reward_type=False,
448
+ **get_kwargs(get_model, exclude_names=['reward_type'], **all_kwargs))
449
+ else:
450
+ # if empty model, then don't load anything, just get gradio up
451
+ model0, tokenizer0, device = None, None, None
452
+ model_state0 = [model0, tokenizer0, device, all_kwargs['base_model']]
453
+
454
+ # get score model
455
+ smodel, stokenizer, sdevice = get_score_model(reward_type=True,
456
+ **get_kwargs(get_score_model, exclude_names=['reward_type'],
457
+ **all_kwargs))
458
+ score_model_state0 = [smodel, stokenizer, sdevice, score_model]
459
+
460
+ if enable_captions:
461
+ if pre_load_caption_model:
462
+ from image_captions import H2OImageCaptionLoader
463
+ caption_loader = H2OImageCaptionLoader(caption_gpu=caption_gpu).load_model()
464
+ else:
465
+ caption_loader = 'gpu' if caption_gpu else 'cpu'
466
+ else:
467
+ caption_loader = False
468
+
469
+ # assume gradio needs everything
470
+ go_gradio(**locals())
471
+
472
+
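main() above repeatedly passes its locals() to helpers through get_kwargs(fn, exclude_names=..., **all_kwargs). The real helper lives in utils.py and is not shown in this diff; the sketch below only illustrates the general keyword-filtering idea, and filter_kwargs and run_cli_stub are hypothetical names used for illustration.

# Rough sketch of the keyword-filtering idea behind get_kwargs(); the real helper is in utils.py.
import inspect

def filter_kwargs(fn, exclude_names=(), **kwargs):
    # keep only kwargs the target function accepts, minus explicitly excluded names
    accepted = set(inspect.signature(fn).parameters)
    return {k: v for k, v in kwargs.items() if k in accepted and k not in exclude_names}

def run_cli_stub(base_model='', load_8bit=False, verbose=False):
    return base_model, load_8bit, verbose

all_kwargs = dict(base_model='h2oai/h2ogpt-oasst1-512-12b', load_8bit=True,
                  verbose=True, unrelated_setting=123)
print(run_cli_stub(**filter_kwargs(run_cli_stub, exclude_names=['verbose'], **all_kwargs)))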
473
+ def get_non_lora_model(base_model, model_loader, load_half, model_kwargs, reward_type,
474
+ gpu_id=0,
475
+ use_auth_token=False,
476
+ trust_remote_code=True,
477
+ offload_folder=None,
478
+ triton_attn=False,
479
+ long_sequence=True,
480
+ ):
481
+ """
482
+ Ensure model gets on correct device
483
+ :param base_model:
484
+ :param model_loader:
485
+ :param load_half:
486
+ :param model_kwargs:
487
+ :param reward_type:
488
+ :param gpu_id:
489
+ :param use_auth_token:
490
+ :param trust_remote_code:
491
+ :param offload_folder:
492
+ :param triton_attn:
493
+ :param long_sequence:
494
+ :return:
495
+ """
496
+ with init_empty_weights():
497
+ from transformers import AutoConfig
498
+ config = AutoConfig.from_pretrained(base_model, use_auth_token=use_auth_token,
499
+ trust_remote_code=trust_remote_code,
500
+ offload_folder=offload_folder)
501
+ if triton_attn and 'mpt-' in base_model.lower():
502
+ config.attn_config['attn_impl'] = 'triton'
503
+ if long_sequence:
504
+ if 'mpt-7b-storywriter' in base_model.lower():
505
+ config.update({"max_seq_len": 83968})
506
+ if 'mosaicml/mpt-7b-chat' in base_model.lower():
507
+ config.update({"max_seq_len": 4096})
508
+ if issubclass(config.__class__, tuple(AutoModel._model_mapping.keys())):
509
+ model = AutoModel.from_config(
510
+ config,
511
+ )
512
+ else:
513
+ # can't infer
514
+ model = None
515
+
516
+ if model is not None:
517
+ # NOTE: Can specify max_memory={0: max_mem, 1: max_mem}, to shard model
518
+ # NOTE: Some models require avoiding sharding some layers,
519
+ # then would pass no_split_module_classes and give list of those layers.
520
+ device_map = infer_auto_device_map(
521
+ model,
522
+ dtype=torch.float16 if load_half else torch.float32,
523
+ )
524
+ if hasattr(model, 'model'):
525
+ device_map_model = infer_auto_device_map(
526
+ model.model,
527
+ dtype=torch.float16 if load_half else torch.float32,
528
+ )
529
+ device_map.update(device_map_model)
530
+ else:
531
+ device_map = "auto"
532
+
533
+ n_gpus = torch.cuda.device_count() if torch.cuda.is_available() else 0
534
+
535
+ if n_gpus > 0:
536
+ if gpu_id >= 0:
537
+ # FIXME: If really distributes model, tend to get things like: ValueError: gpt_neox.embed_in.weight doesn't have any device set.
538
+ # So avoid for now, just put on first GPU, unless score_model, put on last
539
+ if reward_type:
540
+ device_map = {'': n_gpus - 1}
541
+ else:
542
+ device_map = {'': min(n_gpus - 1, gpu_id)}
543
+ if gpu_id == -1:
544
+ device_map = {'': 'cuda'}
545
+ else:
546
+ device_map = {'': 'cpu'}
547
+ model_kwargs['load_in_8bit'] = False
548
+ model_kwargs['load_in_4bit'] = False
549
+ print('device_map: %s' % device_map, flush=True)
550
+
551
+ load_in_8bit = model_kwargs.get('load_in_8bit', False)
552
+ load_in_4bit = model_kwargs.get('load_in_4bit', False)
553
+ model_kwargs['device_map'] = device_map
554
+ pop_unused_model_kwargs(model_kwargs)
555
+
556
+ if load_in_8bit or load_in_4bit or not load_half:
557
+ model = model_loader.from_pretrained(
558
+ base_model,
559
+ config=config,
560
+ **model_kwargs,
561
+ )
562
+ else:
563
+ model = model_loader.from_pretrained(
564
+ base_model,
565
+ config=config,
566
+ **model_kwargs,
567
+ ).half()
568
+ return model
569
+
570
+
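get_non_lora_model() above instantiates the model under init_empty_weights() so a device map can be inferred without allocating real weights. Below is a minimal sketch of that accelerate pattern, using a small public model name purely as an example.

# Minimal sketch of the meta-device + device-map pattern used in get_non_lora_model above.
import torch
from transformers import AutoConfig, AutoModelForCausalLM
from accelerate import init_empty_weights, infer_auto_device_map

config = AutoConfig.from_pretrained("EleutherAI/gpt-neo-125m")  # small example model
with init_empty_weights():
    empty_model = AutoModelForCausalLM.from_config(config)      # no real memory allocated
device_map = infer_auto_device_map(empty_model, dtype=torch.float16)
print(device_map)  # e.g. {'': 0} on a single-GPU box, or layer-by-layer placement when sharded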
571
+ def get_model(
572
+ load_8bit: bool = False,
573
+ load_4bit: bool = False,
574
+ load_half: bool = True,
575
+ infer_devices: bool = True,
576
+ base_model: str = '',
577
+ tokenizer_base_model: str = '',
578
+ lora_weights: str = "",
579
+ gpu_id: int = 0,
580
+
581
+ reward_type: bool = None,
582
+ local_files_only: bool = False,
583
+ resume_download: bool = True,
584
+ use_auth_token: Union[str, bool] = False,
585
+ trust_remote_code: bool = True,
586
+ offload_folder: str = None,
587
+ compile_model: bool = True,
588
+
589
+ verbose: bool = False,
590
+ ):
591
+ """
592
+
593
+ :param load_8bit: load model in 8-bit, not supported by all models
594
+ :param load_4bit: load model in 4-bit, not supported by all models
595
+ :param load_half: load model in 16-bit
596
+ :param infer_devices: Use torch to infer optimal placement of layers on devices (for non-LoRA case)
597
+ For non-LORA case, False will spread shards across multiple GPUs, but this can lead to cuda:x cuda:y mismatches
598
+ So it is not the default
599
+ :param base_model: name/path of base model
600
+ :param tokenizer_base_model: name/path of tokenizer
601
+ :param lora_weights: name/path
602
+ :param gpu_id: which GPU (0..n_gpus-1) or allow all GPUs if relevant (-1)
603
+ :param reward_type: reward type model for sequence classification
604
+ :param local_files_only: use local files instead of from HF
605
+ :param resume_download: resume downloads from HF
606
+ :param use_auth_token: assumes user did on CLI `huggingface-cli login` to access private repo
607
+ :param trust_remote_code: trust code needed by model
608
+ :param offload_folder: offload folder
609
+ :param compile_model: whether to compile torch model
610
+ :param verbose:
611
+ :return:
612
+ """
613
+ if verbose:
614
+ print("Get %s model" % base_model, flush=True)
615
+ if base_model in non_hf_types:
616
+ from gpt4all_llm import get_model_tokenizer_gpt4all
617
+ model, tokenizer, device = get_model_tokenizer_gpt4all(base_model)
618
+ return model, tokenizer, device
619
+
620
+ if lora_weights is not None and lora_weights.strip():
621
+ if verbose:
622
+ print("Get %s lora weights" % lora_weights, flush=True)
623
+ device = get_device()
624
+
625
+ if 'gpt2' in base_model.lower():
626
+ # RuntimeError: where expected condition to be a boolean tensor, but got a tensor with dtype Half
627
+ load_8bit = False
628
+ load_4bit = False
629
+
630
+ assert base_model.strip(), (
631
+ "Please choose a base model with --base_model (CLI) or load one from Models Tab (gradio)"
632
+ )
633
+
634
+ from transformers import AutoConfig
635
+ config = AutoConfig.from_pretrained(base_model, use_auth_token=use_auth_token,
636
+ trust_remote_code=trust_remote_code,
637
+ offload_folder=offload_folder)
638
+ llama_type_from_config = 'llama' in str(config).lower()
639
+ llama_type_from_name = "llama" in base_model.lower()
640
+ llama_type = llama_type_from_config or llama_type_from_name
641
+ if llama_type:
642
+ if verbose:
643
+ print("Detected as llama type from"
644
+ " config (%s) or name (%s)" % (llama_type_from_config, llama_type_from_name), flush=True)
645
+
646
+ model_loader, tokenizer_loader = get_loaders(llama_type=llama_type, model_name=base_model, reward_type=reward_type)
647
+ if not tokenizer_base_model:
648
+ tokenizer_base_model = base_model
649
+
650
+ if tokenizer_loader is not None and not isinstance(tokenizer_loader, str):
651
+ tokenizer = tokenizer_loader.from_pretrained(tokenizer_base_model,
652
+ local_files_only=local_files_only,
653
+ resume_download=resume_download,
654
+ use_auth_token=use_auth_token,
655
+ trust_remote_code=trust_remote_code,
656
+ offload_folder=offload_folder,
657
+ )
658
+ else:
659
+ tokenizer = tokenizer_loader
660
+
661
+ if isinstance(tokenizer, str):
662
+ # already a pipeline, tokenizer_loader is string for task
663
+ model = model_loader(tokenizer,
664
+ model=base_model,
665
+ device=0 if device == "cuda" else -1,
666
+ torch_dtype=torch.float16 if device == 'cuda' else torch.float32)
667
+ else:
668
+ assert device in ["cuda", "cpu"], "Unsupported device %s" % device
669
+ model_kwargs = dict(local_files_only=local_files_only,
670
+ torch_dtype=torch.float16 if device == 'cuda' else torch.float32,
671
+ resume_download=resume_download,
672
+ use_auth_token=use_auth_token,
673
+ trust_remote_code=trust_remote_code,
674
+ offload_folder=offload_folder,
675
+ )
676
+ if 'mbart-' not in base_model.lower() and 'mpt-' not in base_model.lower():
677
+ model_kwargs.update(dict(load_in_8bit=load_8bit,
678
+ load_in_4bit=load_4bit,
679
+ device_map={"": 0} if (load_8bit or load_4bit) and device == 'cuda' else "auto",
680
+ ))
681
+ if 'mpt-' in base_model.lower() and gpu_id >= 0:
682
+ model_kwargs.update(dict(device_map={"": gpu_id} if device == 'cuda' else "cpu"))
683
+
684
+ if 'OpenAssistant/reward-model'.lower() in base_model.lower():
685
+ # FIXME: could put on other GPUs
686
+ model_kwargs['device_map'] = {"": 0} if device == 'cuda' else {"": 'cpu'}
687
+ model_kwargs.pop('torch_dtype', None)
688
+ pop_unused_model_kwargs(model_kwargs)
689
+
690
+ if not lora_weights:
691
+ with torch.device(device):
692
+ if infer_devices:
693
+ model = get_non_lora_model(base_model, model_loader, load_half, model_kwargs, reward_type,
694
+ gpu_id=gpu_id,
695
+ use_auth_token=use_auth_token,
696
+ trust_remote_code=trust_remote_code,
697
+ offload_folder=offload_folder,
698
+ )
699
+ else:
700
+ if load_half and not (load_8bit or load_4bit):
701
+ model = model_loader.from_pretrained(
702
+ base_model,
703
+ **model_kwargs).half()
704
+ else:
705
+ model = model_loader.from_pretrained(
706
+ base_model,
707
+ **model_kwargs)
708
+ elif load_8bit or load_4bit:
709
+ model = model_loader.from_pretrained(
710
+ base_model,
711
+ **model_kwargs
712
+ )
713
+ model = PeftModel.from_pretrained(
714
+ model,
715
+ lora_weights,
716
+ torch_dtype=torch.float16 if device == 'cuda' else torch.float32,
717
+ local_files_only=local_files_only,
718
+ resume_download=resume_download,
719
+ use_auth_token=use_auth_token,
720
+ trust_remote_code=trust_remote_code,
721
+ offload_folder=offload_folder,
722
+ device_map={"": 0} if device == 'cuda' else {"": 'cpu'}, # seems to be required
723
+ )
724
+ else:
725
+ with torch.device(device):
726
+ model = model_loader.from_pretrained(
727
+ base_model,
728
+ **model_kwargs
729
+ )
730
+ model = PeftModel.from_pretrained(
731
+ model,
732
+ lora_weights,
733
+ torch_dtype=torch.float16 if device == 'cuda' else torch.float32,
734
+ local_files_only=local_files_only,
735
+ resume_download=resume_download,
736
+ use_auth_token=use_auth_token,
737
+ trust_remote_code=trust_remote_code,
738
+ offload_folder=offload_folder,
739
+ device_map="auto",
740
+ )
741
+ if load_half:
742
+ model.half()
743
+
744
+ # unwind broken decapoda-research config
745
+ if llama_type:
746
+ model.config.pad_token_id = tokenizer.pad_token_id = 0 # unk
747
+ model.config.bos_token_id = 1
748
+ model.config.eos_token_id = 2
749
+ if 'gpt2' in base_model.lower():
750
+ # add special tokens that otherwise all share the same id
751
+ tokenizer.add_special_tokens({'bos_token': '<bos>',
752
+ 'eos_token': '<eos>',
753
+ 'pad_token': '<pad>'})
754
+
755
+ if not isinstance(tokenizer, str):
756
+ model.eval()
757
+ if torch.__version__ >= "2" and sys.platform != "win32" and compile_model:
758
+ model = torch.compile(model)
759
+
760
+ if hasattr(config, 'max_seq_len') and isinstance(config.max_seq_len, int):
761
+ tokenizer.model_max_length = config.max_seq_len
762
+ elif hasattr(config, 'max_position_embeddings') and isinstance(config.max_position_embeddings, int):
763
+ # help automatically limit inputs to generate
764
+ tokenizer.model_max_length = config.max_position_embeddings
765
+ else:
766
+ if verbose:
767
+ print("Could not determine model_max_length, setting to 2048", flush=True)
768
+ tokenizer.model_max_length = 2048
769
+
770
+ return model, tokenizer, device
771
+
772
+
773
+ def pop_unused_model_kwargs(model_kwargs):
774
+ """
775
+ In-place pop of unused kwargs that are not dependency-upgrade friendly.
776
+ There is no point passing False since it is the default, and dropping it avoids needing to update requirements for new deps.
777
+ :param model_kwargs:
778
+ :return:
779
+ """
780
+ check_list = ['load_in_8bit', 'load_in_4bit']
781
+ for k in check_list:
782
+ if k in model_kwargs and not model_kwargs[k]:
783
+ model_kwargs.pop(k)
784
+
785
+
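For example, the helper above trims falsy bitsandbytes flags in place:

# Example of the in-place cleanup performed by pop_unused_model_kwargs above.
model_kwargs = dict(load_in_8bit=False, load_in_4bit=True, torch_dtype="float16")
pop_unused_model_kwargs(model_kwargs)
print(model_kwargs)  # {'load_in_4bit': True, 'torch_dtype': 'float16'} - only the falsy flag was dropped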
786
+ def get_score_model(score_model: str = None,
787
+ load_8bit: bool = False,
788
+ load_4bit: bool = False,
789
+ load_half: bool = True,
790
+ infer_devices: bool = True,
791
+ base_model: str = '',
792
+ tokenizer_base_model: str = '',
793
+ lora_weights: str = "",
794
+ gpu_id: int = 0,
795
+
796
+ reward_type: bool = None,
797
+ local_files_only: bool = False,
798
+ resume_download: bool = True,
799
+ use_auth_token: Union[str, bool] = False,
800
+ trust_remote_code: bool = True,
801
+ offload_folder: str = None,
802
+ compile_model: bool = True,
803
+
804
+ verbose: bool = False,
805
+ ):
806
+ if score_model is not None and score_model.strip():
807
+ load_8bit = False
808
+ load_4bit = False
809
+ load_half = False
810
+ base_model = score_model.strip()
811
+ tokenizer_base_model = ''
812
+ lora_weights = ''
813
+ llama_type = False
814
+ compile_model = False
815
+ smodel, stokenizer, sdevice = get_model(reward_type=True,
816
+ **get_kwargs(get_model, exclude_names=['reward_type'], **locals()))
817
+ else:
818
+ smodel, stokenizer, sdevice = None, None, None
819
+ return smodel, stokenizer, sdevice
820
+
821
+
822
+ no_default_param_names = [
823
+ 'instruction',
824
+ 'iinput',
825
+ 'context',
826
+ 'instruction_nochat',
827
+ 'iinput_nochat',
828
+ ]
829
+
830
+ eval_func_param_names = ['instruction',
831
+ 'iinput',
832
+ 'context',
833
+ 'stream_output',
834
+ 'prompt_type',
835
+ 'prompt_dict',
836
+ 'temperature',
837
+ 'top_p',
838
+ 'top_k',
839
+ 'num_beams',
840
+ 'max_new_tokens',
841
+ 'min_new_tokens',
842
+ 'early_stopping',
843
+ 'max_time',
844
+ 'repetition_penalty',
845
+ 'num_return_sequences',
846
+ 'do_sample',
847
+ 'chat',
848
+ 'instruction_nochat',
849
+ 'iinput_nochat',
850
+ 'langchain_mode',
851
+ 'top_k_docs',
852
+ 'chunk',
853
+ 'chunk_size',
854
+ 'document_choice',
855
+ ]
856
+
857
+ # form evaluate defaults for submit_nochat_api
858
+ eval_func_param_names_defaults = eval_func_param_names.copy()
859
+ for k in no_default_param_names:
860
+ if k in eval_func_param_names_defaults:
861
+ eval_func_param_names_defaults.remove(k)
862
+
863
+
864
+ def evaluate_from_str(
865
+ model_state,
866
+ my_db_state,
867
+ # START NOTE: Examples must have same order of parameters
868
+ user_kwargs,
869
+ # END NOTE: Examples must have same order of parameters
870
+ default_kwargs=None,
871
+ src_lang=None,
872
+ tgt_lang=None,
873
+ debug=False,
874
+ concurrency_count=None,
875
+ save_dir=None,
876
+ sanitize_bot_response=True,
877
+ model_state0=None,
878
+ memory_restriction_level=None,
879
+ raise_generate_gpu_exceptions=None,
880
+ chat_context=None,
881
+ lora_weights=None,
882
+ load_db_if_exists=True,
883
+ dbs=None,
884
+ user_path=None,
885
+ detect_user_path_changes_every_query=None,
886
+ use_openai_embedding=None,
887
+ use_openai_model=None,
888
+ hf_embedding_model=None,
889
+ chunk=None,
890
+ chunk_size=None,
891
+ db_type=None,
892
+ n_jobs=None,
893
+ first_para=None,
894
+ text_limit=None,
895
+ verbose=False,
896
+ cli=False,
897
+ ):
898
+ if isinstance(user_kwargs, str):
899
+ user_kwargs = ast.literal_eval(user_kwargs)
900
+ # only used for submit_nochat_api
901
+ user_kwargs['chat'] = False
902
+ user_kwargs['stream_output'] = False
903
+
904
+ assert set(list(default_kwargs.keys())) == set(eval_func_param_names)
905
+ # correct ordering. Note some things may not be in default_kwargs, so can't be default of user_kwargs.get()
906
+ args_list = [user_kwargs[k] if k in user_kwargs else default_kwargs[k] for k in eval_func_param_names]
907
+
908
+ ret = evaluate(
909
+ model_state,
910
+ my_db_state,
911
+ # START NOTE: Examples must have same order of parameters
912
+ *tuple(args_list),
913
+ # END NOTE: Examples must have same order of parameters
914
+ src_lang=src_lang,
915
+ tgt_lang=tgt_lang,
916
+ debug=debug,
917
+ concurrency_count=concurrency_count,
918
+ save_dir=save_dir,
919
+ sanitize_bot_response=sanitize_bot_response,
920
+ model_state0=model_state0,
921
+ memory_restriction_level=memory_restriction_level,
922
+ raise_generate_gpu_exceptions=raise_generate_gpu_exceptions,
923
+ chat_context=chat_context,
924
+ lora_weights=lora_weights,
925
+ load_db_if_exists=load_db_if_exists,
926
+ dbs=dbs,
927
+ user_path=user_path,
928
+ detect_user_path_changes_every_query=detect_user_path_changes_every_query,
929
+ use_openai_embedding=use_openai_embedding,
930
+ use_openai_model=use_openai_model,
931
+ hf_embedding_model=hf_embedding_model,
932
+ db_type=db_type,
933
+ n_jobs=n_jobs,
934
+ first_para=first_para,
935
+ text_limit=text_limit,
936
+ verbose=verbose,
937
+ cli=cli,
938
+ )
939
+ try:
940
+ for ret1 in ret:
941
+ yield ret1
942
+ finally:
943
+ # clear before return, in finally in case GPU OOM exception
944
+ clear_torch_cache()
945
+
946
+
947
+ def evaluate(
948
+ model_state,
949
+ my_db_state,
950
+ # START NOTE: Examples must have same order of parameters
951
+ instruction,
952
+ iinput,
953
+ context,
954
+ stream_output,
955
+ prompt_type,
956
+ prompt_dict,
957
+ temperature,
958
+ top_p,
959
+ top_k,
960
+ num_beams,
961
+ max_new_tokens,
962
+ min_new_tokens,
963
+ early_stopping,
964
+ max_time,
965
+ repetition_penalty,
966
+ num_return_sequences,
967
+ do_sample,
968
+ chat,
969
+ instruction_nochat,
970
+ iinput_nochat,
971
+ langchain_mode,
972
+ top_k_docs,
973
+ chunk,
974
+ chunk_size,
975
+ document_choice,
976
+ # END NOTE: Examples must have same order of parameters
977
+ src_lang=None,
978
+ tgt_lang=None,
979
+ debug=False,
980
+ concurrency_count=None,
981
+ save_dir=None,
982
+ sanitize_bot_response=True,
983
+ model_state0=None,
984
+ memory_restriction_level=None,
985
+ raise_generate_gpu_exceptions=None,
986
+ chat_context=None,
987
+ lora_weights=None,
988
+ load_db_if_exists=True,
989
+ dbs=None,
990
+ user_path=None,
991
+ detect_user_path_changes_every_query=None,
992
+ use_openai_embedding=None,
993
+ use_openai_model=None,
994
+ hf_embedding_model=None,
995
+ db_type=None,
996
+ n_jobs=None,
997
+ first_para=None,
998
+ text_limit=None,
999
+ verbose=False,
1000
+ cli=False,
1001
+ ):
1002
+ # ensure passed these
1003
+ assert concurrency_count is not None
1004
+ assert memory_restriction_level is not None
1005
+ assert raise_generate_gpu_exceptions is not None
1006
+ assert chat_context is not None
1007
+ assert use_openai_embedding is not None
1008
+ assert use_openai_model is not None
1009
+ assert hf_embedding_model is not None
1010
+ assert db_type is not None
1011
+ assert top_k_docs is not None and isinstance(top_k_docs, int)
1012
+ assert chunk is not None and isinstance(chunk, bool)
1013
+ assert chunk_size is not None and isinstance(chunk_size, int)
1014
+ assert n_jobs is not None
1015
+ assert first_para is not None
1016
+
1017
+ if debug:
1018
+ locals_dict = locals().copy()
1019
+ locals_dict.pop('model_state', None)
1020
+ locals_dict.pop('model_state0', None)
1021
+ print(locals_dict)
1022
+
1023
+ no_model_msg = "Please choose a base model with --base_model (CLI) or load in Models Tab (gradio).\nThen start New Conversation"
1024
+
1025
+ if model_state0 is None:
1026
+ # e.g. for no gradio case, set dummy value, else should be set
1027
+ model_state0 = [None, None, None, None]
1028
+
1029
+ if model_state is not None and len(model_state) == 4 and not isinstance(model_state[0], str):
1030
+ # try to free-up original model (i.e. list was passed as reference)
1031
+ if model_state0 is not None and model_state0[0] is not None:
1032
+ model_state0[0].cpu()
1033
+ model_state0[0] = None
1034
+ # try to free-up original tokenizer (i.e. list was passed as reference)
1035
+ if model_state0 is not None and model_state0[1] is not None:
1036
+ model_state0[1] = None
1037
+ clear_torch_cache()
1038
+ model, tokenizer, device, base_model = model_state
1039
+ elif model_state0 is not None and len(model_state0) == 4 and model_state0[0] is not None:
1040
+ assert isinstance(model_state[0], str)
1041
+ model, tokenizer, device, base_model = model_state0
1042
+ else:
1043
+ raise AssertionError(no_model_msg)
1044
+
1045
+ if base_model is None:
1046
+ raise AssertionError(no_model_msg)
1047
+
1048
+ assert base_model.strip(), no_model_msg
1049
+ assert model, "Model is missing"
1050
+ assert tokenizer, "Tokenizer is missing"
1051
+
1052
+ # choose chat or non-chat mode
1053
+ if not chat:
1054
+ instruction = instruction_nochat
1055
+ iinput = iinput_nochat
1056
+
1057
+ if not context:
1058
+ # get hidden context if have one
1059
+ context = get_context(chat_context, prompt_type)
1060
+
1061
+ prompter = Prompter(prompt_type, prompt_dict, debug=debug, chat=chat, stream_output=stream_output)
1062
+ data_point = dict(context=context, instruction=instruction, input=iinput)
1063
+ prompt = prompter.generate_prompt(data_point)
1064
+
1065
+ # THIRD PLACE where LangChain referenced, but imports only occur if enabled and have db to use
1066
+ assert langchain_mode in langchain_modes, "Invalid langchain_mode %s" % langchain_mode
1067
+ if langchain_mode in ['MyData'] and my_db_state is not None and len(my_db_state) > 0 and my_db_state[0] is not None:
1068
+ db1 = my_db_state[0]
1069
+ elif dbs is not None and langchain_mode in dbs:
1070
+ db1 = dbs[langchain_mode]
1071
+ else:
1072
+ db1 = None
1073
+ if langchain_mode not in [False, 'Disabled', 'ChatLLM', 'LLM'] and db1 is not None or base_model in non_hf_types:
1074
+ query = instruction if not iinput else "%s\n%s" % (instruction, iinput)
1075
+ outr = ""
1076
+ # use smaller cut_distanct for wiki_full since so many matches could be obtained, and often irrelevant unless close
1077
+ from gpt_langchain import run_qa_db
1078
+ for r in run_qa_db(query=query,
1079
+ model_name=base_model, model=model, tokenizer=tokenizer,
1080
+ stream_output=stream_output,
1081
+ prompter=prompter,
1082
+ load_db_if_exists=load_db_if_exists,
1083
+ db=db1,
1084
+ user_path=user_path,
1085
+ detect_user_path_changes_every_query=detect_user_path_changes_every_query,
1086
+ max_new_tokens=max_new_tokens,
1087
+ cut_distanct=1.1 if langchain_mode in ['wiki_full'] else 1.64, # FIXME, too arbitrary
1088
+ use_openai_embedding=use_openai_embedding,
1089
+ use_openai_model=use_openai_model,
1090
+ hf_embedding_model=hf_embedding_model,
1091
+ first_para=first_para,
1092
+ text_limit=text_limit,
1093
+ chunk=chunk,
1094
+ chunk_size=chunk_size,
1095
+ langchain_mode=langchain_mode,
1096
+ document_choice=document_choice,
1097
+ db_type=db_type,
1098
+ top_k_docs=top_k_docs,
1099
+ temperature=temperature,
1100
+ repetition_penalty=repetition_penalty,
1101
+ top_k=top_k,
1102
+ top_p=top_p,
1103
+ prompt_type=prompt_type,
1104
+ prompt_dict=prompt_dict,
1105
+ n_jobs=n_jobs,
1106
+ verbose=verbose,
1107
+ cli=cli,
1108
+ ):
1109
+ outr, extra = r # doesn't accumulate, new answer every yield, so only save that full answer
1110
+ yield dict(response=outr, sources=extra)
1111
+ if save_dir:
1112
+ save_generate_output(output=outr, base_model=base_model, save_dir=save_dir)
1113
+ if verbose:
1114
+ print(
1115
+ 'Post-Generate Langchain: %s decoded_output: %s' % (str(datetime.now()), len(outr) if outr else -1),
1116
+ flush=True)
1117
+ if outr or base_model in non_hf_types:
1118
+ # if got no response (e.g. not showing sources and got no sources,
1119
+ # so nothing to give to LLM), then slip through and ask LLM
1120
+ # Or if llama/gptj, then just return since they had no response and can't go down below code path
1121
+ # clear before return, since .then() never done if from API
1122
+ clear_torch_cache()
1123
+ return
1124
+
1125
+ if isinstance(tokenizer, str):
1126
+ # pipeline
1127
+ if tokenizer == "summarization":
1128
+ key = 'summary_text'
1129
+ else:
1130
+ raise RuntimeError("No such task type %s" % tokenizer)
1131
+ # NOTE: uses max_length only
1132
+ yield dict(response=model(prompt, max_length=max_new_tokens)[0][key], sources='')
1133
+
1134
+ if 'mbart-' in base_model.lower():
1135
+ assert src_lang is not None
1136
+ tokenizer.src_lang = languages_covered()[src_lang]
1137
+
1138
+ if chat:
1139
+ # override, ignore user change
1140
+ num_return_sequences = 1
1141
+ stopping_criteria = get_stopping(prompt_type, prompt_dict, tokenizer, device)
1142
+ _, _, max_length_tokenize, max_prompt_length = get_cutoffs(memory_restriction_level,
1143
+ model_max_length=tokenizer.model_max_length)
1144
+ prompt = prompt[-max_prompt_length:]
1145
+ inputs = tokenizer(prompt,
1146
+ return_tensors="pt",
1147
+ truncation=True,
1148
+ max_length=max_length_tokenize)
1149
+ if inputs['input_ids'].shape[1] >= max_length_tokenize - 1:
1150
+ print("Cutting off input: %s %s" % (inputs['input_ids'].shape[1], max_length_tokenize), flush=True)
1151
+ if debug and len(inputs["input_ids"]) > 0:
1152
+ print('input_ids length', len(inputs["input_ids"][0]), flush=True)
1153
+ input_ids = inputs["input_ids"].to(device)
1154
+ # CRITICAL LIMIT else will fail
1155
+ max_max_tokens = tokenizer.model_max_length
1156
+ max_input_tokens = max_max_tokens - max_new_tokens
1157
+ input_ids = input_ids[:, -max_input_tokens:]
1158
+ generation_config = GenerationConfig(
1159
+ temperature=float(temperature),
1160
+ top_p=float(top_p),
1161
+ top_k=top_k,
1162
+ num_beams=num_beams,
1163
+ do_sample=do_sample,
1164
+ repetition_penalty=float(repetition_penalty),
1165
+ num_return_sequences=num_return_sequences,
1166
+ renormalize_logits=True,
1167
+ remove_invalid_values=True,
1168
+ )
1169
+
1170
+ gen_kwargs = dict(input_ids=input_ids,
1171
+ generation_config=generation_config,
1172
+ return_dict_in_generate=True,
1173
+ output_scores=True,
1174
+ max_new_tokens=max_new_tokens, # counts newly generated tokens only, not the prompt
1175
+ min_new_tokens=min_new_tokens, # counts newly generated tokens only, not the prompt
1176
+ early_stopping=early_stopping, # False, True, "never"
1177
+ max_time=max_time,
1178
+ stopping_criteria=stopping_criteria,
1179
+ )
1180
+ if 'gpt2' in base_model.lower():
1181
+ gen_kwargs.update(dict(bos_token_id=tokenizer.bos_token_id, pad_token_id=tokenizer.eos_token_id))
1182
+ elif 'mbart-' in base_model.lower():
1183
+ assert tgt_lang is not None
1184
+ tgt_lang = languages_covered()[tgt_lang]
1185
+ gen_kwargs.update(dict(forced_bos_token_id=tokenizer.lang_code_to_id[tgt_lang]))
1186
+ else:
1187
+ gen_kwargs.update(dict(pad_token_id=tokenizer.eos_token_id))
1188
+
1189
+ decoder_kwargs = dict(skip_special_tokens=True,
1190
+ clean_up_tokenization_spaces=True)
1191
+
1192
+ decoder = functools.partial(tokenizer.decode,
1193
+ **decoder_kwargs
1194
+ )
1195
+ decoder_raw_kwargs = dict(skip_special_tokens=False,
1196
+ clean_up_tokenization_spaces=True)
1197
+
1198
+ decoder_raw = functools.partial(tokenizer.decode,
1199
+ **decoder_raw_kwargs
1200
+ )
1201
+
1202
+ with torch.no_grad():
1203
+ context_class_cast = NullContext if device == 'cpu' or lora_weights else torch.autocast
1204
+ with context_class_cast(device):
1205
+ # protection for gradio not keeping track of closed users,
1206
+ # else hit bitsandbytes lack of thread safety:
1207
+ # https://github.com/h2oai/h2ogpt/issues/104
1208
+ # but only makes sense if concurrency_count == 1
1209
+ context_class = NullContext # if concurrency_count > 1 else filelock.FileLock
1210
+ if verbose:
1211
+ print('Pre-Generate: %s' % str(datetime.now()), flush=True)
1212
+ decoded_output = None
1213
+ with context_class("generate.lock"):
1214
+ if verbose:
1215
+ print('Generate: %s' % str(datetime.now()), flush=True)
1216
+ # decoded tokenized prompt can deviate from prompt due to special characters
1217
+ inputs_decoded = decoder(input_ids[0])
1218
+ inputs_decoded_raw = decoder_raw(input_ids[0])
1219
+ if inputs_decoded == prompt:
1220
+ # normal
1221
+ pass
1222
+ elif inputs_decoded.lstrip() == prompt.lstrip():
1223
+ # sometimes extra space in front, make prompt same for prompt removal
1224
+ prompt = inputs_decoded
1225
+ elif inputs_decoded_raw == prompt:
1226
+ # some models specify special tokens that are part of normal prompt, so can't skip them
1227
+ inputs_decoded = prompt = inputs_decoded_raw
1228
+ decoder = decoder_raw
1229
+ decoder_kwargs = decoder_raw_kwargs
1230
+ elif inputs_decoded_raw.replace("<unk> ", "").replace("<unk>", "").replace('\n', ' ').replace(' ',
1231
+ '') == prompt.replace(
1232
+ '\n', ' ').replace(' ', ''):
1233
+ inputs_decoded = prompt = inputs_decoded_raw
1234
+ decoder = decoder_raw
1235
+ decoder_kwargs = decoder_raw_kwargs
1236
+ else:
1237
+ if verbose:
1238
+ print("WARNING: Special characters in prompt", flush=True)
1239
+ if stream_output:
1240
+ skip_prompt = False
1241
+ streamer = H2OTextIteratorStreamer(tokenizer, skip_prompt=skip_prompt, block=False,
1242
+ **decoder_kwargs)
1243
+ gen_kwargs.update(dict(streamer=streamer))
1244
+ target = wrapped_partial(generate_with_exceptions, model.generate,
1245
+ prompt=prompt, inputs_decoded=inputs_decoded,
1246
+ raise_generate_gpu_exceptions=raise_generate_gpu_exceptions,
1247
+ **gen_kwargs)
1248
+ bucket = queue.Queue()
1249
+ thread = EThread(target=target, streamer=streamer, bucket=bucket)
1250
+ thread.start()
1251
+ outputs = ""
1252
+ try:
1253
+ for new_text in streamer:
1254
+ if bucket.qsize() > 0 or thread.exc:
1255
+ thread.join()
1256
+ outputs += new_text
1257
+ yield dict(response=prompter.get_response(outputs, prompt=inputs_decoded,
1258
+ sanitize_bot_response=sanitize_bot_response),
1259
+ sources='')
1260
+ except BaseException:
1261
+ # if any exception, raise that exception if was from thread, first
1262
+ if thread.exc:
1263
+ raise thread.exc
1264
+ raise
1265
+ finally:
1266
+ # clear before return, since .then() never done if from API
1267
+ clear_torch_cache()
1268
+ # in case no exception and didn't join with thread yet, then join
1269
+ if not thread.exc:
1270
+ thread.join()
1271
+ # in case raise StopIteration or broke queue loop in streamer, but still have exception
1272
+ if thread.exc:
1273
+ raise thread.exc
1274
+ decoded_output = outputs
1275
+ else:
1276
+ try:
1277
+ outputs = model.generate(**gen_kwargs)
1278
+ finally:
1279
+ clear_torch_cache() # has to be here for API submit_nochat_api since .then() not called
1280
+ outputs = [decoder(s) for s in outputs.sequences]
1281
+ yield dict(response=prompter.get_response(outputs, prompt=inputs_decoded,
1282
+ sanitize_bot_response=sanitize_bot_response), sources='')
1283
+ if outputs and len(outputs) >= 1:
1284
+ decoded_output = prompt + outputs[0]
1285
+ if save_dir and decoded_output:
1286
+ save_generate_output(output=decoded_output, base_model=base_model, save_dir=save_dir)
1287
+ if verbose:
1288
+ print('Post-Generate: %s decoded_output: %s' % (
1289
+ str(datetime.now()), len(decoded_output) if decoded_output else -1), flush=True)
1290
+
1291
+
1292
+ inputs_list_names = list(inspect.signature(evaluate).parameters)
1293
+ state_names = ['model_state', 'my_db_state']
1294
+ inputs_kwargs_list = [x for x in inputs_list_names if x not in eval_func_param_names + state_names]
1295
+
1296
+
1297
+ def get_cutoffs(memory_restriction_level, for_context=False, model_max_length=2048):
1298
+ # help to avoid errors like:
1299
+ # RuntimeError: The size of tensor a (2048) must match the size of tensor b (2049) at non-singleton dimension 3
1300
+ # RuntimeError: expected scalar type Half but found Float
1301
+ # with - 256
1302
+ if memory_restriction_level > 0:
1303
+ max_length_tokenize = 768 - 256 if memory_restriction_level <= 2 else 512 - 256
1304
+ else:
1305
+ max_length_tokenize = model_max_length - 256
1306
+ cutoff_len = max_length_tokenize * 4 # if reaches limit, then can't generate new tokens
1307
+ output_smallest = 30 * 4
1308
+ max_prompt_length = cutoff_len - output_smallest
1309
+
1310
+ if for_context:
1311
+ # then lower even more to avoid later chop, since just estimate tokens in context bot
1312
+ max_prompt_length = max(64, int(max_prompt_length * 0.8))
1313
+
1314
+ return cutoff_len, output_smallest, max_length_tokenize, max_prompt_length
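+ # Worked example (illustrative, using the ~4 chars/token heuristic above): with
+ # memory_restriction_level=0 and model_max_length=2048, max_length_tokenize = 2048 - 256 = 1792
+ # tokens, cutoff_len = 1792 * 4 = 7168 chars, output_smallest = 120 chars, and
+ # max_prompt_length = 7168 - 120 = 7048 chars (or max(64, int(7048 * 0.8)) = 5638 if for_context).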
1315
+
1316
+
1317
+ class H2OTextIteratorStreamer(TextIteratorStreamer):
1318
+ """
1319
+ normally a timeout is required to handle exceptions, since a plain get() would block forever;
1320
+ the H2O version of TextIteratorStreamer instead loops over get() (honoring block/timeout) so stops and exceptions can be handled
1321
+ """
1322
+
1323
+ def __init__(self, tokenizer, skip_prompt: bool = False, timeout: typing.Optional[float] = None,
1324
+ block=True, **decode_kwargs):
1325
+ super().__init__(tokenizer, skip_prompt, **decode_kwargs)
1326
+ self.text_queue = queue.Queue()
1327
+ self.stop_signal = None
1328
+ self.do_stop = False
1329
+ self.timeout = timeout
1330
+ self.block = block
1331
+
1332
+ def on_finalized_text(self, text: str, stream_end: bool = False):
1333
+ """Put the new text in the queue. If the stream is ending, also put a stop signal in the queue."""
1334
+ self.text_queue.put(text, timeout=self.timeout)
1335
+ if stream_end:
1336
+ self.text_queue.put(self.stop_signal, timeout=self.timeout)
1337
+
1338
+ def __iter__(self):
1339
+ return self
1340
+
1341
+ def __next__(self):
1342
+ while True:
1343
+ try:
1344
+ value = self.stop_signal # value looks unused in pycharm, not true
1345
+ if self.do_stop:
1346
+ print("hit stop", flush=True)
1347
+ # could raise or break, maybe best to raise and make parent see if any exception in thread
1348
+ raise StopIteration()
1349
+ # break
1350
+ value = self.text_queue.get(block=self.block, timeout=self.timeout)
1351
+ break
1352
+ except queue.Empty:
1353
+ time.sleep(0.01)
1354
+ if value == self.stop_signal:
1355
+ raise StopIteration()
1356
+ else:
1357
+ return value
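+ # Design note (descriptive, based on the loop above): when constructed with block=False (as in
+ # evaluate), __next__ polls the queue and sleeps 0.01s on queue.Empty, so a do_stop flag set by
+ # another thread is observed within roughly 10ms and StopIteration is raised; the parent thread
+ # can then surface any exception raised by the producer thread.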
1358
+
1359
+
1360
+ def generate_with_exceptions(func, *args, prompt='', inputs_decoded='', raise_generate_gpu_exceptions=True, **kwargs):
1361
+ try:
1362
+ func(*args, **kwargs)
1363
+ except torch.cuda.OutOfMemoryError as e:
1364
+ print("GPU OOM 2: prompt: %s inputs_decoded: %s exception: %s" % (prompt, inputs_decoded, str(e)),
1365
+ flush=True)
1366
+ if 'input_ids' in kwargs:
1367
+ if kwargs['input_ids'] is not None:
1368
+ kwargs['input_ids'].cpu()
1369
+ kwargs['input_ids'] = None
1370
+ traceback.print_exc()
1371
+ clear_torch_cache()
1372
+ return
1373
+ except (Exception, RuntimeError) as e:
1374
+ if 'Expected all tensors to be on the same device' in str(e) or \
1375
+ 'expected scalar type Half but found Float' in str(e) or \
1376
+ 'probability tensor contains either' in str(e) or \
1377
+ 'cublasLt ran into an error!' in str(e) or \
1378
+ 'mat1 and mat2 shapes cannot be multiplied' in str(e):
1379
+ print(
1380
+ "GPU Error: prompt: %s inputs_decoded: %s exception: %s" % (prompt, inputs_decoded, str(e)),
1381
+ flush=True)
1382
+ traceback.print_exc()
1383
+ clear_torch_cache()
1384
+ if raise_generate_gpu_exceptions:
1385
+ raise
1386
+ return
1387
+ else:
1388
+ clear_torch_cache()
1389
+ if raise_generate_gpu_exceptions:
1390
+ raise
1391
+
1392
+
1393
+ def get_generate_params(model_lower, chat,
1394
+ stream_output, show_examples,
1395
+ prompt_type, prompt_dict,
1396
+ temperature, top_p, top_k, num_beams,
1397
+ max_new_tokens, min_new_tokens, early_stopping, max_time,
1398
+ repetition_penalty, num_return_sequences,
1399
+ do_sample,
1400
+ top_k_docs, chunk, chunk_size,
1401
+ verbose):
1402
+ use_defaults = False
1403
+ use_default_examples = True
1404
+ examples = []
1405
+ task_info = 'LLM'
1406
+ if model_lower:
1407
+ print(f"Using Model {model_lower}", flush=True)
1408
+ else:
1409
+ print("No model defined yet", flush=True)
1410
+
1411
+ min_new_tokens = min_new_tokens if min_new_tokens is not None else 0
1412
+ early_stopping = early_stopping if early_stopping is not None else False
1413
+ max_time_defaults = 60 * 3
1414
+ max_time = max_time if max_time is not None else max_time_defaults
1415
+
1416
+ if not prompt_type and model_lower in inv_prompt_type_to_model_lower:
1417
+ prompt_type = inv_prompt_type_to_model_lower[model_lower]
1418
+ if verbose:
1419
+ print("Auto-selecting prompt_type=%s for %s" % (prompt_type, model_lower), flush=True)
1420
+
1421
+ # examples at first don't include chat, instruction_nochat, iinput_nochat, added at end
1422
+ if show_examples is None:
1423
+ if chat:
1424
+ show_examples = False
1425
+ else:
1426
+ show_examples = True
1427
+
1428
+ summarize_example1 = """Jeff: Can I train a 🤗 Transformers model on Amazon SageMaker?
1429
+ Philipp: Sure you can use the new Hugging Face Deep Learning Container.
1430
+ Jeff: ok.
1431
+ Jeff: and how can I get started?
1432
+ Jeff: where can I find documentation?
1433
+ Philipp: ok, ok you can find everything here. https://huggingface.co/blog/the-partnership-amazon-sagemaker-and-hugging-face"""
1434
+
1435
+ use_placeholder_instruction_as_example = False
1436
+ if 'bart-large-cnn-samsum' in model_lower or 'flan-t5-base-samsum' in model_lower:
1437
+ placeholder_instruction = summarize_example1
1438
+ placeholder_input = ""
1439
+ use_defaults = True
1440
+ use_default_examples = False
1441
+ use_placeholder_instruction_as_example = True
1442
+ task_info = "Summarization"
1443
+ elif 't5-' in model_lower or 't5' == model_lower or 'flan-' in model_lower:
1444
+ placeholder_instruction = "The square root of x is the cube root of y. What is y to the power of 2, if x = 4?"
1445
+ placeholder_input = ""
1446
+ use_defaults = True
1447
+ use_default_examples = True
1448
+ task_info = "Multi-Task: Q/A, translation, Chain-of-Thought, Logical Reasoning, Summarization, etc. Best to use task prefix as trained on, e.g. `translate English to German: ` (space after colon)"
1449
+ elif 'mbart-' in model_lower:
1450
+ placeholder_instruction = "The girl has long hair."
1451
+ placeholder_input = ""
1452
+ use_defaults = True
1453
+ use_default_examples = False
1454
+ use_placeholder_instruction_as_example = True
1455
+ elif 'gpt2' in model_lower:
1456
+ placeholder_instruction = "The sky is"
1457
+ placeholder_input = ""
1458
+ prompt_type = prompt_type or 'plain'
1459
+ use_default_examples = True # some will be odd "continuations" but can be ok
1460
+ use_placeholder_instruction_as_example = True
1461
+ task_info = "Auto-complete phrase, code, etc."
1462
+ use_defaults = True
1463
+ else:
1464
+ if chat:
1465
+ placeholder_instruction = "Enter a question or imperative."
1466
+ else:
1467
+ placeholder_instruction = "Give detailed answer for whether Einstein or Newton is smarter."
1468
+ placeholder_input = ""
1469
+ if model_lower:
1470
+ # default is plain, because might rely upon trust_remote_code to handle prompting
1471
+ prompt_type = prompt_type or 'plain'
1472
+ else:
1473
+ prompt_type = ''
1474
+ task_info = "No task"
1475
+ if prompt_type == 'instruct':
1476
+ task_info = "Answer question or follow imperative as instruction with optional input."
1477
+ elif prompt_type == 'plain':
1478
+ task_info = "Auto-complete phrase, code, etc."
1479
+ elif prompt_type == 'human_bot':
1480
+ if chat:
1481
+ task_info = "Chat (Shift-Enter to give question/imperative, input concatenated with instruction)"
1482
+ else:
1483
+ task_info = "Ask question/imperative (input concatenated with instruction)"
1484
+
1485
+ # revert to plain if still nothing
1486
+ prompt_type = prompt_type or 'plain'
1487
+ if use_defaults:
1488
+ temperature = 1.0 if temperature is None else temperature
1489
+ top_p = 1.0 if top_p is None else top_p
1490
+ top_k = 40 if top_k is None else top_k
1491
+ num_beams = num_beams or 1
1492
+ max_new_tokens = max_new_tokens or 128
1493
+ repetition_penalty = repetition_penalty or 1.07
1494
+ num_return_sequences = min(num_beams, num_return_sequences or 1)
1495
+ do_sample = False if do_sample is None else do_sample
1496
+ else:
1497
+ temperature = 0.1 if temperature is None else temperature
1498
+ top_p = 0.75 if top_p is None else top_p
1499
+ top_k = 40 if top_k is None else top_k
1500
+ num_beams = num_beams or 1
1501
+ max_new_tokens = max_new_tokens or 256
1502
+ repetition_penalty = repetition_penalty or 1.07
1503
+ num_return_sequences = min(num_beams, num_return_sequences or 1)
1504
+ do_sample = False if do_sample is None else do_sample
1505
+ # doesn't include chat, instruction_nochat, iinput_nochat, added later
1506
+ params_list = ["",
1507
+ stream_output,
1508
+ prompt_type, prompt_dict,
1509
+ temperature, top_p, top_k, num_beams,
1510
+ max_new_tokens, min_new_tokens,
1511
+ early_stopping, max_time, repetition_penalty, num_return_sequences, do_sample]
1512
+
1513
+ if use_placeholder_instruction_as_example:
1514
+ examples += [[placeholder_instruction, ''] + params_list]
1515
+
1516
+ if use_default_examples:
1517
+ examples += [
1518
+ ["Translate English to French", "Good morning"] + params_list,
1519
+ ["Give detailed answer for whether Einstein or Newton is smarter.", ''] + params_list,
1520
+ ["Explain in detailed list, all the best practices for coding in python.", ''] + params_list,
1521
+ [
1522
+ "Create a markdown table with 3 rows for the primary colors, and 2 columns, with color name and hex codes.",
1523
+ ''] + params_list,
1524
+ ['Translate to German: My name is Arthur', ''] + params_list,
1525
+ ["Please answer to the following question. Who is going to be the next Ballon d'or?", ''] + params_list,
1526
+ ['Can Geoffrey Hinton have a conversation with George Washington? Give the rationale before answering.',
1527
+ ''] + params_list,
1528
+ ['Please answer the following question. What is the boiling point of Nitrogen?', ''] + params_list,
1529
+ ['Answer the following yes/no question. Can you write a whole Haiku in a single tweet?', ''] + params_list,
1530
+ ["Simplify the following expression: (False or False and True). Explain your answer.", ''] + params_list,
1531
+ [
1532
+ "Premise: At my age you will probably have learnt one lesson. Hypothesis: It's not certain how many lessons you'll learn by your thirties. Does the premise entail the hypothesis?",
1533
+ ''] + params_list,
1534
+ ['The square root of x is the cube root of y. What is y to the power of 2, if x = 4?', ''] + params_list,
1535
+ [
1536
+ 'Answer the following question by reasoning step by step. The cafeteria had 23 apples. If they used 20 for lunch, and bought 6 more, how many apples do they have?',
1537
+ ''] + params_list,
1538
+ ["""def area_of_rectangle(a: float, b: float):
1539
+ \"\"\"Return the area of the rectangle.\"\"\"""", ''] + params_list,
1540
+ ["""# a function in native python:
1541
+ def mean(a):
1542
+ return sum(a)/len(a)
1543
+
1544
+ # the same function using numpy:
1545
+ import numpy as np
1546
+ def mean(a):""", ''] + params_list,
1547
+ ["""X = np.random.randn(100, 100)
1548
+ y = np.random.randint(0, 1, 100)
1549
+
1550
+ # fit random forest classifier with 20 estimators""", ''] + params_list,
1551
+ ]
1552
+ # add summary example
1553
+ examples += [
1554
+ [summarize_example1, 'Summarize' if prompt_type not in ['plain', 'instruct_simple'] else ''] + params_list]
1555
+
1556
+ src_lang = "English"
1557
+ tgt_lang = "Russian"
1558
+
1559
+ # move to correct position
1560
+ for example in examples:
1561
+ example += [chat, '', '', 'Disabled', top_k_docs, chunk, chunk_size, [DocumentChoices.All_Relevant.name]]
1562
+ # adjust examples if non-chat mode
1563
+ if not chat:
1564
+ example[eval_func_param_names.index('instruction_nochat')] = example[
1565
+ eval_func_param_names.index('instruction')]
1566
+ example[eval_func_param_names.index('instruction')] = ''
1567
+
1568
+ example[eval_func_param_names.index('iinput_nochat')] = example[eval_func_param_names.index('iinput')]
1569
+ example[eval_func_param_names.index('iinput')] = ''
1570
+ assert len(example) == len(eval_func_param_names), "Wrong example: %s %s" % (
1571
+ len(example), len(eval_func_param_names))
1572
+
1573
+ if prompt_type == PromptType.custom.name and not prompt_dict:
1574
+ raise ValueError("Unexpected to get non-empty prompt_dict=%s for prompt_type=%s" % (prompt_dict, prompt_type))
1575
+
1576
+ # get prompt_dict from prompt_type, so user can see in UI etc., or for custom do nothing except check format
1577
+ prompt_dict, error0 = get_prompt(prompt_type, prompt_dict,
1578
+ chat=False, context='', reduced=False, return_dict=True)
1579
+ if error0:
1580
+ raise RuntimeError("Prompt wrong: %s" % error0)
1581
+
1582
+ return placeholder_instruction, placeholder_input, \
1583
+ stream_output, show_examples, \
1584
+ prompt_type, prompt_dict, \
1585
+ temperature, top_p, top_k, num_beams, \
1586
+ max_new_tokens, min_new_tokens, early_stopping, max_time, \
1587
+ repetition_penalty, num_return_sequences, \
1588
+ do_sample, \
1589
+ src_lang, tgt_lang, \
1590
+ examples, \
1591
+ task_info
1592
+
1593
+
1594
+ def languages_covered():
1595
+ # https://huggingface.co/facebook/mbart-large-50-many-to-many-mmt#languages-covered
1596
+ covered = """Arabic (ar_AR), Czech (cs_CZ), German (de_DE), English (en_XX), Spanish (es_XX), Estonian (et_EE), Finnish (fi_FI), French (fr_XX), Gujarati (gu_IN), Hindi (hi_IN), Italian (it_IT), Japanese (ja_XX), Kazakh (kk_KZ), Korean (ko_KR), Lithuanian (lt_LT), Latvian (lv_LV), Burmese (my_MM), Nepali (ne_NP), Dutch (nl_XX), Romanian (ro_RO), Russian (ru_RU), Sinhala (si_LK), Turkish (tr_TR), Vietnamese (vi_VN), Chinese (zh_CN), Afrikaans (af_ZA), Azerbaijani (az_AZ), Bengali (bn_IN), Persian (fa_IR), Hebrew (he_IL), Croatian (hr_HR), Indonesian (id_ID), Georgian (ka_GE), Khmer (km_KH), Macedonian (mk_MK), Malayalam (ml_IN), Mongolian (mn_MN), Marathi (mr_IN), Polish (pl_PL), Pashto (ps_AF), Portuguese (pt_XX), Swedish (sv_SE), Swahili (sw_KE), Tamil (ta_IN), Telugu (te_IN), Thai (th_TH), Tagalog (tl_XX), Ukrainian (uk_UA), Urdu (ur_PK), Xhosa (xh_ZA), Galician (gl_ES), Slovene (sl_SI)"""
1597
+ covered = covered.split(', ')
1598
+ covered = {x.split(' ')[0]: x.split(' ')[1].replace(')', '').replace('(', '') for x in covered}
1599
+ return covered
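+ # Example (illustrative): languages_covered() returns a plain dict mapping the first word of each
+ # language name to its mbart code, e.g. {'Arabic': 'ar_AR', 'English': 'en_XX', 'Russian': 'ru_RU', ...},
+ # which is what the mbart branches above use to set tokenizer.src_lang and forced_bos_token_id.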
1600
+
1601
+
1602
+ def get_context(chat_context, prompt_type):
1603
+ if chat_context and prompt_type == 'human_bot':
1604
+ context0 = """<bot>: I am an intelligent, helpful, truthful, and fair assistant named h2oGPT, who will give accurate, balanced, and reliable responses. I will not respond with I don't know or I don't understand.
1605
+ <human>: I am a human person seeking useful assistance and request all questions be answered completely, and typically expect detailed responses. Give answers in numbered list format if several distinct but related items are being listed."""
1606
+ else:
1607
+ context0 = ''
1608
+ return context0
1609
+
1610
+
1611
+ def score_qa(smodel, stokenizer, max_length_tokenize, question, answer, cutoff_len):
1612
+ question = question[-cutoff_len:]
1613
+ answer = answer[-cutoff_len:]
1614
+
1615
+ inputs = stokenizer(question, answer,
1616
+ return_tensors="pt",
1617
+ truncation=True,
1618
+ max_length=max_length_tokenize).to(smodel.device)
1619
+ try:
1620
+ score = torch.sigmoid(smodel(**inputs).logits[0]).cpu().detach().numpy()[0]
1621
+ except torch.cuda.OutOfMemoryError as e:
1622
+ print("GPU OOM 3: question: %s answer: %s exception: %s" % (question, answer, str(e)), flush=True)
1623
+ del inputs
1624
+ traceback.print_exc()
1625
+ clear_torch_cache()
1626
+ return 'Response Score: GPU OOM'
1627
+ except (Exception, RuntimeError) as e:
1628
+ if 'Expected all tensors to be on the same device' in str(e) or \
1629
+ 'expected scalar type Half but found Float' in str(e) or \
1630
+ 'probability tensor contains either' in str(e) or \
1631
+ 'cublasLt ran into an error!' in str(e):
1632
+ print("GPU Error: question: %s answer: %s exception: %s" % (question, answer, str(e)),
1633
+ flush=True)
1634
+ traceback.print_exc()
1635
+ clear_torch_cache()
1636
+ return 'Response Score: GPU Error'
1637
+ else:
1638
+ raise
1639
+ os.environ['TOKENIZERS_PARALLELISM'] = 'true'
1640
+ return score
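+ # Note (descriptive, based on the code above): score is sigmoid(logit) from the reward model, so a
+ # successful call returns a float in (0, 1); on GPU OOM or GPU errors a string is returned instead,
+ # which callers are expected to handle.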
1641
+
1642
+
1643
+ def check_locals(**kwargs):
1644
+ # ensure everything in evaluate is here
1645
+ can_skip_because_locally_generated = no_default_param_names + [
1646
+ # get_model:
1647
+ 'reward_type'
1648
+ ]
1649
+ for k in eval_func_param_names:
1650
+ if k in can_skip_because_locally_generated:
1651
+ continue
1652
+ assert k in kwargs, "Missing %s" % k
1653
+ for k in inputs_kwargs_list:
1654
+ if k in can_skip_because_locally_generated:
1655
+ continue
1656
+ assert k in kwargs, "Missing %s" % k
1657
+
1658
+ for k in list(inspect.signature(get_model).parameters):
1659
+ if k in can_skip_because_locally_generated:
1660
+ continue
1661
+ assert k in kwargs, "Missing %s" % k
1662
+
1663
+
1664
+ def get_max_max_new_tokens(model_state, **kwargs):
1665
+ if kwargs['max_new_tokens'] and kwargs['user_set_max_new_tokens']:
1666
+ max_max_new_tokens = kwargs['max_new_tokens']
1667
+ elif kwargs['memory_restriction_level'] == 1:
1668
+ max_max_new_tokens = 768
1669
+ elif kwargs['memory_restriction_level'] == 2:
1670
+ max_max_new_tokens = 512
1671
+ elif kwargs['memory_restriction_level'] >= 3:
1672
+ max_max_new_tokens = 256
1673
+ else:
1674
+ if not isinstance(model_state[1], str):
1675
+ max_max_new_tokens = model_state[1].model_max_length
1676
+ else:
1677
+ # FIXME: Need to update after new model loaded, so user can control with slider
1678
+ max_max_new_tokens = 2048
1679
+ return max_max_new_tokens
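+ # Note (descriptive): the memory_restriction_level branches cap max new tokens at 768/512/256;
+ # otherwise the loaded tokenizer's model_max_length (or the 2048 fallback) bounds the slider.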
1680
+
1681
+
1682
+ if __name__ == "__main__":
1683
+ """
1684
+ Examples:
1685
+
1686
+ WORLD_SIZE=4 CUDA_VISIBLE_DEVICES="0,1,2,3" torchrun --nproc_per_node=4 --master_port=1234 generate.py --base_model='EleutherAI/gpt-j-6B' --lora_weights=lora-alpaca_6B
1687
+ python generate.py --base_model='EleutherAI/gpt-j-6B' --lora_weights='lora-alpaca_6B'
1688
+ python generate.py --base_model='EleutherAI/gpt-neox-20b' --lora_weights='lora-alpaca_20B'
1689
+
1690
+ # generate without lora weights, no prompt
1691
+ python generate.py --base_model='EleutherAI/gpt-neox-20b' --prompt_type='plain'
1692
+ python generate.py --base_model='togethercomputer/GPT-NeoXT-Chat-Base-20B' --prompt_type='dai_faq'
1693
+
1694
+ python generate.py --base_model='togethercomputer/GPT-NeoXT-Chat-Base-20B' --prompt_type='dai_faq' --lora_weights='lora_20B_daifaq'
1695
+ # OpenChatKit settings:
1696
+ python generate.py --base_model='togethercomputer/GPT-NeoXT-Chat-Base-20B' --prompt_type='human_bot' --debug=True --num_beams=1 --temperature=0.6 --top_k=40 --top_p=1.0
1697
+
1698
+ python generate.py --base_model='distilgpt2' --prompt_type='plain' --debug=True --num_beams=1 --temperature=0.6 --top_k=40 --top_p=1.0 --share=False
1699
+ python generate.py --base_model='t5-large' --prompt_type='simple_instruct'
1700
+ python generate.py --base_model='philschmid/bart-large-cnn-samsum'
1701
+ python generate.py --base_model='philschmid/flan-t5-base-samsum'
1702
+ python generate.py --base_model='facebook/mbart-large-50-many-to-many-mmt'
1703
+
1704
+ python generate.py --base_model='togethercomputer/GPT-NeoXT-Chat-Base-20B' --prompt_type='human_bot' --lora_weights='GPT-NeoXT-Chat-Base-20B.merged.json.8_epochs.57b2892c53df5b8cefac45f84d019cace803ef26.28'
1705
+
1706
+ must have 4*48GB GPU and run without 8bit in order for sharding to work with infer_devices=False
1707
+ can also pass --prompt_type='human_bot' and model can somewhat handle instructions without being instruct tuned
1708
+ python generate.py --base_model=decapoda-research/llama-65b-hf --load_8bit=False --infer_devices=False --prompt_type='human_bot'
1709
+
1710
+ python generate.py --base_model=h2oai/h2ogpt-oig-oasst1-512-6_9b
1711
+ """
1712
+ fire.Fire(main)
gpt4all_llm.py DELETED
@@ -1 +0,0 @@
1
- ../../gpt4all_llm.py
 
 
gpt4all_llm.py ADDED
@@ -0,0 +1,258 @@
1
+ import inspect
2
+ import os
3
+ import sys
4
+ from typing import Dict, Any, Optional, List
5
+ from langchain.callbacks.manager import CallbackManagerForLLMRun
6
+ from pydantic import root_validator
7
+ from langchain.llms import gpt4all
8
+ from dotenv import dotenv_values
9
+
10
+
11
+ class FakeTokenizer:
12
+ model_max_length = 2048
13
+
14
+ def encode(self, x, *args, **kwargs):
15
+ return dict(input_ids=[x])
16
+
17
+ def decode(self, x, *args, **kwargs):
18
+ return x
19
+
20
+ def __call__(self, x, *args, **kwargs):
21
+ return self.encode(x, *args, **kwargs)
22
+
23
+
24
+ def get_model_tokenizer_gpt4all(base_model, **kwargs):
25
+ # defaults (some of these are generation parameters, so need to be passed in at generation time)
26
+ model_kwargs = dict(n_threads=os.cpu_count() // 2,
27
+ temp=kwargs.get('temperature', 0.2),
28
+ top_p=kwargs.get('top_p', 0.75),
29
+ top_k=kwargs.get('top_k', 40),
30
+ n_ctx=2048 - 256)
31
+ env_gpt4all_file = ".env_gpt4all"
32
+ model_kwargs.update(dotenv_values(env_gpt4all_file))
33
+
34
+ if base_model == "llama":
35
+ if 'model_path_llama' not in model_kwargs:
36
+ raise ValueError("No model_path_llama in %s" % env_gpt4all_file)
37
+ model_path = model_kwargs.pop('model_path_llama')
38
+ # FIXME: GPT4All version of llama doesn't handle new quantization, so use llama_cpp_python
39
+ from llama_cpp import Llama
40
+ # llama sets some things at init model time, not generation time
41
+ func_names = list(inspect.signature(Llama.__init__).parameters)
42
+ model_kwargs = {k: v for k, v in model_kwargs.items() if k in func_names}
43
+ model_kwargs['n_ctx'] = int(model_kwargs['n_ctx'])
44
+ model = Llama(model_path=model_path, **model_kwargs)
45
+ elif base_model == "gpt4all_llama":
46
+ if 'model_name_gpt4all_llama' not in model_kwargs and 'model_path_gpt4all_llama' not in model_kwargs:
47
+ raise ValueError("No model_name_gpt4all_llama or model_path_gpt4all_llama in %s" % env_gpt4all_file)
48
+ model_name = model_kwargs.pop('model_name_gpt4all_llama')
49
+ model_type = 'llama'
50
+ from gpt4all import GPT4All as GPT4AllModel
51
+ model = GPT4AllModel(model_name=model_name, model_type=model_type)
52
+ elif base_model == "gptj":
53
+ if 'model_name_gptj' not in model_kwargs and 'model_path_gptj' not in model_kwargs:
54
+ raise ValueError("No model_name_gpt4j or model_path_gpt4j in %s" % env_gpt4all_file)
55
+ model_name = model_kwargs.pop('model_name_gptj')
56
+ model_type = 'gptj'
57
+ from gpt4all import GPT4All as GPT4AllModel
58
+ model = GPT4AllModel(model_name=model_name, model_type=model_type)
59
+ else:
60
+ raise ValueError("No such base_model %s" % base_model)
61
+ return model, FakeTokenizer(), 'cpu'
62
+
63
+
64
+ from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
65
+
66
+
67
+ class H2OStreamingStdOutCallbackHandler(StreamingStdOutCallbackHandler):
68
+
69
+ def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
70
+ """Run on new LLM token. Only available when streaming is enabled."""
71
+ # streaming to std already occurs without this
72
+ # sys.stdout.write(token)
73
+ # sys.stdout.flush()
74
+ pass
75
+
76
+
77
+ def get_model_kwargs(env_kwargs, default_kwargs, cls):
78
+ # default from class
79
+ model_kwargs = {k: v.default for k, v in dict(inspect.signature(cls).parameters).items()}
80
+ # from our defaults
81
+ model_kwargs.update(default_kwargs)
82
+ # from user defaults
83
+ model_kwargs.update(env_kwargs)
84
+ # ensure only valid keys
85
+ func_names = list(inspect.signature(cls).parameters)
86
+ model_kwargs = {k: v for k, v in model_kwargs.items() if k in func_names}
87
+ return model_kwargs
88
+
89
+
90
+ def get_llm_gpt4all(model_name,
91
+ model=None,
92
+ max_new_tokens=256,
93
+ temperature=0.1,
94
+ repetition_penalty=1.0,
95
+ top_k=40,
96
+ top_p=0.7,
97
+ verbose=False):
98
+ env_gpt4all_file = ".env_gpt4all"
99
+ env_kwargs = dotenv_values(env_gpt4all_file)
100
+ callbacks = [H2OStreamingStdOutCallbackHandler()]
101
+ n_ctx = env_kwargs.pop('n_ctx', 2048 - max_new_tokens)
102
+ default_kwargs = dict(context_erase=0.5,
103
+ n_batch=1,
104
+ n_ctx=n_ctx,
105
+ n_predict=max_new_tokens,
106
+ repeat_last_n=64 if repetition_penalty != 1.0 else 0,
107
+ repeat_penalty=repetition_penalty,
108
+ temp=temperature,
109
+ temperature=temperature,
110
+ top_k=top_k,
111
+ top_p=top_p,
112
+ use_mlock=True,
113
+ verbose=verbose)
114
+ if model_name == 'llama':
115
+ cls = H2OLlamaCpp
116
+ model_path = env_kwargs.pop('model_path_llama') if model is None else model
117
+ model_kwargs = get_model_kwargs(env_kwargs, default_kwargs, cls)
118
+ model_kwargs.update(dict(model_path=model_path, callbacks=callbacks))
119
+ llm = cls(**model_kwargs)
120
+ llm.client.verbose = verbose
121
+ elif model_name == 'gpt4all_llama':
122
+ cls = H2OGPT4All
123
+ model_path = env_kwargs.pop('model_path_gpt4all_llama') if model is None else model
124
+ model_kwargs = get_model_kwargs(env_kwargs, default_kwargs, cls)
125
+ model_kwargs.update(dict(model=model_path, backend='llama', callbacks=callbacks))
126
+ llm = cls(**model_kwargs)
127
+ elif model_name == 'gptj':
128
+ cls = H2OGPT4All
129
+ model_path = env_kwargs.pop('model_path_gptj') if model is None else model
130
+ model_kwargs = get_model_kwargs(env_kwargs, default_kwargs, cls)
131
+ model_kwargs.update(dict(model=model_path, backend='gptj', callbacks=callbacks))
132
+ llm = cls(**model_kwargs)
133
+ else:
134
+ raise RuntimeError("No such model_name %s" % model_name)
135
+ return llm
136
+
137
+
138
+ class H2OGPT4All(gpt4all.GPT4All):
139
+ model: Any
140
+ """Path to the pre-trained GPT4All model file."""
141
+
142
+ @root_validator()
143
+ def validate_environment(cls, values: Dict) -> Dict:
144
+ """Validate that the python package exists in the environment."""
145
+ try:
146
+ if isinstance(values["model"], str):
147
+ from gpt4all import GPT4All as GPT4AllModel
148
+
149
+ full_path = values["model"]
150
+ model_path, delimiter, model_name = full_path.rpartition("/")
151
+ model_path += delimiter
152
+
153
+ values["client"] = GPT4AllModel(
154
+ model_name=model_name,
155
+ model_path=model_path or None,
156
+ model_type=values["backend"],
157
+ allow_download=False,
158
+ )
159
+ else:
160
+ values["client"] = values["model"]
161
+ values["backend"] = values["client"].model.model_type
162
+
163
+ except ImportError:
164
+ raise ValueError(
165
+ "Could not import gpt4all python package. "
166
+ "Please install it with `pip install gpt4all`."
167
+ )
168
+ return values
169
+
170
+ def _call(
171
+ self,
172
+ prompt: str,
173
+ stop: Optional[List[str]] = None,
174
+ run_manager: Optional[CallbackManagerForLLMRun] = None,
175
+ ) -> str:
176
+ # Roughly 4 chars per token if natural language
177
+ prompt = prompt[-self.n_ctx * 4:]
178
+ verbose = False
179
+ if verbose:
180
+ print("_call prompt: %s" % prompt, flush=True)
181
+ return super()._call(prompt, stop=stop, run_manager=run_manager)
182
+
183
+
184
+ from langchain.llms import LlamaCpp
185
+
186
+
187
+ class H2OLlamaCpp(LlamaCpp):
188
+ model_path: Any
189
+ """Path to the pre-trained GPT4All model file."""
190
+
191
+ @root_validator()
192
+ def validate_environment(cls, values: Dict) -> Dict:
193
+ """Validate that llama-cpp-python library is installed."""
194
+ if isinstance(values["model_path"], str):
195
+ model_path = values["model_path"]
196
+ model_param_names = [
197
+ "lora_path",
198
+ "lora_base",
199
+ "n_ctx",
200
+ "n_parts",
201
+ "seed",
202
+ "f16_kv",
203
+ "logits_all",
204
+ "vocab_only",
205
+ "use_mlock",
206
+ "n_threads",
207
+ "n_batch",
208
+ "use_mmap",
209
+ "last_n_tokens_size",
210
+ ]
211
+ model_params = {k: values[k] for k in model_param_names}
212
+ # For backwards compatibility, only include if non-null.
213
+ if values["n_gpu_layers"] is not None:
214
+ model_params["n_gpu_layers"] = values["n_gpu_layers"]
215
+
216
+ try:
217
+ from llama_cpp import Llama
218
+
219
+ values["client"] = Llama(model_path, **model_params)
220
+ except ImportError:
221
+ raise ModuleNotFoundError(
222
+ "Could not import llama-cpp-python library. "
223
+ "Please install the llama-cpp-python library to "
224
+ "use this embedding model: pip install llama-cpp-python"
225
+ )
226
+ except Exception as e:
227
+ raise ValueError(
228
+ f"Could not load Llama model from path: {model_path}. "
229
+ f"Received error {e}"
230
+ )
231
+ else:
232
+ values["client"] = values["model_path"]
233
+ return values
234
+
235
+ def _call(
236
+ self,
237
+ prompt: str,
238
+ stop: Optional[List[str]] = None,
239
+ run_manager: Optional[CallbackManagerForLLMRun] = None,
240
+ ) -> str:
241
+ verbose = False
242
+ # tokenize twice, just to count tokens, since llama cpp python wrapper has no way to truncate
243
+ # still have to avoid crazy sizes, else hit llama_tokenize: too many tokens -- might still hit, not fatal
244
+ prompt = prompt[-self.n_ctx * 4:]
245
+ prompt_tokens = self.client.tokenize(b" " + prompt.encode("utf-8"))
246
+ num_prompt_tokens = len(prompt_tokens)
247
+ if num_prompt_tokens > self.n_ctx:
248
+ # conservative by using int()
249
+ chars_per_token = int(len(prompt) / num_prompt_tokens)
250
+ prompt = prompt[-self.n_ctx * chars_per_token:]
251
+ if verbose:
252
+ print("reducing tokens, assuming average of %s chars/token: %s" % chars_per_token, flush=True)
253
+ prompt_tokens2 = self.client.tokenize(b" " + prompt.encode("utf-8"))
254
+ num_prompt_tokens2 = len(prompt_tokens2)
255
+ print("reduced tokens from %d -> %d" % (num_prompt_tokens, num_prompt_tokens2), flush=True)
256
+ if verbose:
257
+ print("_call prompt: %s" % prompt, flush=True)
258
+ return super()._call(prompt, stop=stop, run_manager=run_manager)
gpt_langchain.py DELETED
@@ -1 +0,0 @@
1
- ../../gpt_langchain.py
 
 
gpt_langchain.py ADDED
@@ -0,0 +1,1633 @@
1
+ import glob
2
+ import inspect
3
+ import os
4
+ import pathlib
5
+ import pickle
6
+ import queue
7
+ import random
8
+ import shutil
9
+ import subprocess
10
+ import sys
11
+ import tempfile
12
+ import traceback
13
+ import uuid
14
+ import zipfile
15
+ from collections import defaultdict
16
+ from datetime import datetime
17
+ from functools import reduce
18
+ from operator import concat
19
+
20
+ from joblib import Parallel, delayed
21
+ from langchain.embeddings import HuggingFaceInstructEmbeddings
22
+ from tqdm import tqdm
23
+
24
+ from enums import DocumentChoices
25
+ from prompter import non_hf_types, PromptType
26
+ from utils import wrapped_partial, EThread, import_matplotlib, sanitize_filename, makedirs, get_url, flatten_list, \
27
+ get_device, ProgressParallel, remove, hash_file, clear_torch_cache
28
+
29
+ import_matplotlib()
30
+
31
+ import numpy as np
32
+ import pandas as pd
33
+ import requests
34
+ from langchain.chains.qa_with_sources import load_qa_with_sources_chain
35
+ # , GCSDirectoryLoader, GCSFileLoader
36
+ # , OutlookMessageLoader # GPL3
37
+ # ImageCaptionLoader, # use our own wrapper
38
+ # ReadTheDocsLoader, # no special file, some path, so have to give as special option
39
+ from langchain.document_loaders import PyPDFLoader, TextLoader, CSVLoader, PythonLoader, TomlLoader, \
40
+ UnstructuredURLLoader, UnstructuredHTMLLoader, UnstructuredWordDocumentLoader, UnstructuredMarkdownLoader, \
41
+ EverNoteLoader, UnstructuredEmailLoader, UnstructuredODTLoader, UnstructuredPowerPointLoader, \
42
+ UnstructuredEPubLoader, UnstructuredImageLoader, UnstructuredRTFLoader, ArxivLoader
43
+ from langchain.text_splitter import RecursiveCharacterTextSplitter
44
+ from langchain.chains.question_answering import load_qa_chain
45
+ from langchain.docstore.document import Document
46
+ from langchain import PromptTemplate
47
+ from langchain.vectorstores import Chroma
48
+
49
+
50
+ def get_db(sources, use_openai_embedding=False, db_type='faiss',
51
+ persist_directory="db_dir", load_db_if_exists=True,
52
+ langchain_mode='notset',
53
+ collection_name=None,
54
+ hf_embedding_model="sentence-transformers/all-MiniLM-L6-v2"):
55
+ if not sources:
56
+ return None
57
+
58
+ # get embedding model
59
+ embedding = get_embedding(use_openai_embedding, hf_embedding_model=hf_embedding_model)
60
+ assert collection_name is not None or langchain_mode != 'notset'
61
+ if collection_name is None:
62
+ collection_name = langchain_mode.replace(' ', '_')
63
+
64
+ # Create vector database
65
+ if db_type == 'faiss':
66
+ from langchain.vectorstores import FAISS
67
+ db = FAISS.from_documents(sources, embedding)
68
+ elif db_type == 'weaviate':
69
+ import weaviate
70
+ from weaviate.embedded import EmbeddedOptions
71
+ from langchain.vectorstores import Weaviate
72
+
73
+ if os.getenv('WEAVIATE_URL', None):
74
+ client = _create_local_weaviate_client()
75
+ else:
76
+ client = weaviate.Client(
77
+ embedded_options=EmbeddedOptions()
78
+ )
79
+ index_name = collection_name.capitalize()
80
+ db = Weaviate.from_documents(documents=sources, embedding=embedding, client=client, by_text=False,
81
+ index_name=index_name)
82
+ elif db_type == 'chroma':
83
+ assert persist_directory is not None
84
+ os.makedirs(persist_directory, exist_ok=True)
85
+
86
+ # see if already actually have persistent db, and deal with possible changes in embedding
87
+ db = get_existing_db(None, persist_directory, load_db_if_exists, db_type, use_openai_embedding, langchain_mode,
88
+ hf_embedding_model, verbose=False)
89
+ if db is None:
90
+ db = Chroma.from_documents(documents=sources,
91
+ embedding=embedding,
92
+ persist_directory=persist_directory,
93
+ collection_name=collection_name,
94
+ anonymized_telemetry=False)
95
+ db.persist()
96
+ clear_embedding(db)
97
+ save_embed(db, use_openai_embedding, hf_embedding_model)
98
+ else:
99
+ # then just add
100
+ db, num_new_sources, new_sources_metadata = add_to_db(db, sources, db_type=db_type,
101
+ use_openai_embedding=use_openai_embedding,
102
+ hf_embedding_model=hf_embedding_model)
103
+ else:
104
+ raise RuntimeError("No such db_type=%s" % db_type)
105
+
106
+ return db
107
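+ # Usage sketch (kept as a comment so nothing runs at import time); assumes the default
+ # HF embedding model can be downloaded on first use:
+ #   docs = [Document(page_content="hello world", metadata=dict(source="demo.txt"))]
+ #   db = get_db(docs, db_type='faiss', langchain_mode='UserData')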
+
108
+
109
+ def _get_unique_sources_in_weaviate(db):
110
+ batch_size = 100
111
+ id_source_list = []
112
+ result = db._client.data_object.get(class_name=db._index_name, limit=batch_size)
113
+
114
+ while result['objects']:
115
+ id_source_list += [(obj['id'], obj['properties']['source']) for obj in result['objects']]
116
+ last_id = id_source_list[-1][0]
117
+ result = db._client.data_object.get(class_name=db._index_name, limit=batch_size, after=last_id)
118
+
119
+ unique_sources = {source for _, source in id_source_list}
120
+ return unique_sources
121
+
122
+
123
+ def add_to_db(db, sources, db_type='faiss',
124
+ avoid_dup_by_file=False,
125
+ avoid_dup_by_content=True,
126
+ use_openai_embedding=False,
127
+ hf_embedding_model=None):
128
+ assert hf_embedding_model is not None
129
+ num_new_sources = len(sources)
130
+ if not sources:
131
+ return db, num_new_sources, []
132
+ if db_type == 'faiss':
133
+ db.add_documents(sources)
134
+ elif db_type == 'weaviate':
135
+ # FIXME: only control by file name, not hash yet
136
+ if avoid_dup_by_file or avoid_dup_by_content:
137
+ unique_sources = _get_unique_sources_in_weaviate(db)
138
+ sources = [x for x in sources if x.metadata['source'] not in unique_sources]
139
+ num_new_sources = len(sources)
140
+ if num_new_sources == 0:
141
+ return db, num_new_sources, []
142
+ db.add_documents(documents=sources)
143
+ elif db_type == 'chroma':
144
+ collection = db.get()
145
+ # files we already have:
146
+ metadata_files = set([x['source'] for x in collection['metadatas']])
147
+ if avoid_dup_by_file:
148
+ # Too weak a check if file content changed; assume parent shouldn't pass True for this for now
149
+ raise RuntimeError("Not desired code path")
150
+ sources = [x for x in sources if x.metadata['source'] not in metadata_files]
151
+ if avoid_dup_by_content:
152
+ # look at hash, instead of page_content
153
+ # migration: If no hash previously, avoid updating,
154
+ # since don't know if need to update and may be expensive to redo all unhashed files
155
+ metadata_hash_ids = set(
156
+ [x['hashid'] for x in collection['metadatas'] if 'hashid' in x and x['hashid'] not in ["None", None]])
157
+ # avoid sources with same hash
158
+ sources = [x for x in sources if x.metadata.get('hashid') not in metadata_hash_ids]
159
+ # get new file names that match existing file names; delete existing files we are overriding
160
+ dup_metadata_files = set([x.metadata['source'] for x in sources if x.metadata['source'] in metadata_files])
161
+ print("Removing %s duplicate files from db because ingesting those as new documents" % len(
162
+ dup_metadata_files), flush=True)
163
+ client_collection = db._client.get_collection(name=db._collection.name,
164
+ embedding_function=db._collection._embedding_function)
165
+ for dup_file in dup_metadata_files:
166
+ dup_file_meta = dict(source=dup_file)
167
+ try:
168
+ client_collection.delete(where=dup_file_meta)
169
+ except KeyError:
170
+ pass
171
+ num_new_sources = len(sources)
172
+ if num_new_sources == 0:
173
+ return db, num_new_sources, []
174
+ db.add_documents(documents=sources)
175
+ db.persist()
176
+ clear_embedding(db)
177
+ save_embed(db, use_openai_embedding, hf_embedding_model)
178
+ else:
179
+ raise RuntimeError("No such db_type=%s" % db_type)
180
+
181
+ new_sources_metadata = [x.metadata for x in sources]
182
+
183
+ return db, num_new_sources, new_sources_metadata
184
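+ # Usage sketch (comment only): append new chunks to an existing Chroma db, de-duplicating
+ # by content hash; hf_embedding_model should match the model the db was built with:
+ #   db, num_new, new_meta = add_to_db(db, new_docs, db_type='chroma',
+ #                                     hf_embedding_model="sentence-transformers/all-MiniLM-L6-v2")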
+
185
+
186
+ def create_or_update_db(db_type, persist_directory, collection_name,
187
+ sources, use_openai_embedding, add_if_exists, verbose, hf_embedding_model):
188
+ if db_type == 'weaviate':
189
+ import weaviate
190
+ from weaviate.embedded import EmbeddedOptions
191
+
192
+ if os.getenv('WEAVIATE_URL', None):
193
+ client = _create_local_weaviate_client()
194
+ else:
195
+ client = weaviate.Client(
196
+ embedded_options=EmbeddedOptions()
197
+ )
198
+
199
+ index_name = collection_name.replace(' ', '_').capitalize()
200
+ if client.schema.exists(index_name) and not add_if_exists:
201
+ client.schema.delete_class(index_name)
202
+ if verbose:
203
+ print("Removing %s" % index_name, flush=True)
204
+ elif db_type == 'chroma':
205
+ if not os.path.isdir(persist_directory) or not add_if_exists:
206
+ if os.path.isdir(persist_directory):
207
+ if verbose:
208
+ print("Removing %s" % persist_directory, flush=True)
209
+ remove(persist_directory)
210
+ if verbose:
211
+ print("Generating db", flush=True)
212
+
213
+ if not add_if_exists:
214
+ if verbose:
215
+ print("Generating db", flush=True)
216
+ else:
217
+ if verbose:
218
+ print("Loading and updating db", flush=True)
219
+
220
+ db = get_db(sources,
221
+ use_openai_embedding=use_openai_embedding,
222
+ db_type=db_type,
223
+ persist_directory=persist_directory,
224
+ langchain_mode=collection_name,
225
+ hf_embedding_model=hf_embedding_model)
226
+
227
+ return db
228
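+ # Usage sketch (comment only, arguments are positional); 'db_dir_UserData' is the
+ # hypothetical persist directory produced by get_persist_directory('UserData') below:
+ #   db = create_or_update_db('chroma', 'db_dir_UserData', 'UserData', docs,
+ #                            False, True, True, "sentence-transformers/all-MiniLM-L6-v2")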
+
229
+
230
+ def get_embedding(use_openai_embedding, hf_embedding_model="sentence-transformers/all-MiniLM-L6-v2"):
231
+ # Get embedding model
232
+ if use_openai_embedding:
233
+ assert os.getenv("OPENAI_API_KEY") is not None, "Set ENV OPENAI_API_KEY"
234
+ from langchain.embeddings import OpenAIEmbeddings
235
+ embedding = OpenAIEmbeddings()
236
+ else:
237
+ # to ensure can fork without deadlock
238
+ from langchain.embeddings import HuggingFaceEmbeddings
239
+
240
+ device, torch_dtype, context_class = get_device_dtype()
241
+ model_kwargs = dict(device=device)
242
+ if 'instructor' in hf_embedding_model:
243
+ encode_kwargs = {'normalize_embeddings': True}
244
+ embedding = HuggingFaceInstructEmbeddings(model_name=hf_embedding_model,
245
+ model_kwargs=model_kwargs,
246
+ encode_kwargs=encode_kwargs)
247
+ else:
248
+ embedding = HuggingFaceEmbeddings(model_name=hf_embedding_model, model_kwargs=model_kwargs)
249
+ return embedding
250
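+ # Usage sketch (comment only); with use_openai_embedding=False this downloads the HF model
+ # on first use:
+ #   embedding = get_embedding(False, hf_embedding_model="sentence-transformers/all-MiniLM-L6-v2")
+ #   vector = embedding.embed_query("What is a GPU?")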
+
251
+
252
+ def get_answer_from_sources(chain, sources, question):
253
+ return chain(
254
+ {
255
+ "input_documents": sources,
256
+ "question": question,
257
+ },
258
+ return_only_outputs=True,
259
+ )["output_text"]
260
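+ # Usage sketch (comment only); 'llm' and 'docs' are assumed to come from get_llm() and a
+ # similarity search, respectively:
+ #   chain = load_qa_with_sources_chain(llm)
+ #   answer = get_answer_from_sources(chain, docs, "What is Unix?")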
+
261
+
262
+ def get_llm(use_openai_model=False, model_name=None, model=None,
263
+ tokenizer=None, stream_output=False,
264
+ max_new_tokens=256,
265
+ temperature=0.1,
266
+ repetition_penalty=1.0,
267
+ top_k=40,
268
+ top_p=0.7,
269
+ prompt_type=None,
270
+ prompt_dict=None,
271
+ prompter=None,
272
+ verbose=False,
273
+ ):
274
+ if use_openai_model:
275
+ from langchain.llms import OpenAI
276
+ llm = OpenAI(temperature=0)
277
+ model_name = 'openai'
278
+ streamer = None
279
+ prompt_type = 'plain'
280
+ elif model_name in non_hf_types:
281
+ from gpt4all_llm import get_llm_gpt4all
282
+ llm = get_llm_gpt4all(model_name, model=model, max_new_tokens=max_new_tokens,
283
+ temperature=temperature,
284
+ repetition_penalty=repetition_penalty,
285
+ top_k=top_k,
286
+ top_p=top_p,
287
+ verbose=verbose,
288
+ )
289
+ streamer = None
290
+ prompt_type = 'plain'
291
+ else:
292
+ from transformers import AutoTokenizer, AutoModelForCausalLM
293
+
294
+ if model is None:
295
+ # only used if didn't pass model in
296
+ assert tokenizer is None
297
+ prompt_type = 'human_bot'
298
+ model_name = 'h2oai/h2ogpt-oasst1-512-12b'
299
+ # model_name = 'h2oai/h2ogpt-oig-oasst1-512-6_9b'
300
+ # model_name = 'h2oai/h2ogpt-oasst1-512-20b'
301
+ tokenizer = AutoTokenizer.from_pretrained(model_name)
302
+ device, torch_dtype, context_class = get_device_dtype()
303
+
304
+ with context_class(device):
305
+ load_8bit = True
306
+ # FIXME: for now not to spread across hetero GPUs
307
+ # device_map={"": 0} if load_8bit and device == 'cuda' else "auto"
308
+ device_map = {"": 0} if device == 'cuda' else "auto"
309
+ model = AutoModelForCausalLM.from_pretrained(model_name,
310
+ device_map=device_map,
311
+ torch_dtype=torch_dtype,
312
+ load_in_8bit=load_8bit)
313
+
314
+ max_max_tokens = tokenizer.model_max_length
315
+ gen_kwargs = dict(max_new_tokens=max_new_tokens,
316
+ return_full_text=True,
317
+ early_stopping=False,
318
+ handle_long_generation='hole')
319
+
320
+ if stream_output:
321
+ skip_prompt = False
322
+ from generate import H2OTextIteratorStreamer
323
+ decoder_kwargs = {}
324
+ streamer = H2OTextIteratorStreamer(tokenizer, skip_prompt=skip_prompt, block=False, **decoder_kwargs)
325
+ gen_kwargs.update(dict(streamer=streamer))
326
+ else:
327
+ streamer = None
328
+
329
+ from h2oai_pipeline import H2OTextGenerationPipeline
330
+ pipe = H2OTextGenerationPipeline(model=model, use_prompter=True,
331
+ prompter=prompter,
332
+ prompt_type=prompt_type,
333
+ prompt_dict=prompt_dict,
334
+ sanitize_bot_response=True,
335
+ chat=False, stream_output=stream_output,
336
+ tokenizer=tokenizer,
337
+ max_input_tokens=max_max_tokens - max_new_tokens,
338
+ **gen_kwargs)
339
+ # pipe.task = "text-generation"
340
+ # below makes it listen only to our prompt removal,
341
+ # not the built-in prompt removal, which is less general and not specific to our model
342
+ pipe.task = "text2text-generation"
343
+
344
+ from langchain.llms import HuggingFacePipeline
345
+ llm = HuggingFacePipeline(pipeline=pipe)
346
+ return llm, model_name, streamer, prompt_type
347
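+ # Usage sketch (comment only): with model=None this downloads the default
+ # h2oai/h2ogpt-oasst1-512-12b model in 8-bit, so a CUDA GPU is effectively assumed:
+ #   llm, model_name, streamer, prompt_type = get_llm(stream_output=False, max_new_tokens=256)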
+
348
+
349
+ def get_device_dtype():
350
+ # torch.device("cuda") leads to cuda:x cuda:y mismatches for multi-GPU consistently
351
+ import torch
352
+ n_gpus = torch.cuda.device_count() if torch.cuda.is_available() else 0
353
+ device = 'cpu' if n_gpus == 0 else 'cuda'
354
+ # from utils import NullContext
355
+ # context_class = NullContext if n_gpus > 1 or n_gpus == 0 else context_class
356
+ context_class = torch.device
357
+ torch_dtype = torch.float16 if device == 'cuda' else torch.float32
358
+ return device, torch_dtype, context_class
359
+
360
+
361
+ def get_wiki_data(title, first_paragraph_only, text_limit=None, take_head=True):
362
+ """
363
+ Get wikipedia data from online
364
+ :param title:
365
+ :param first_paragraph_only:
366
+ :param text_limit:
367
+ :param take_head:
368
+ :return:
369
+ """
370
+ filename = 'wiki_%s_%s_%s_%s.data' % (first_paragraph_only, title, text_limit, take_head)
371
+ url = f"https://en.wikipedia.org/w/api.php?format=json&action=query&prop=extracts&explaintext=1&titles={title}"
372
+ if first_paragraph_only:
373
+ url += "&exintro=1"
374
+ import json
375
+ if not os.path.isfile(filename):
376
+ data = requests.get(url).json()
377
+ json.dump(data, open(filename, 'wt'))
378
+ else:
379
+ data = json.load(open(filename, "rt"))
380
+ page_content = list(data["query"]["pages"].values())[0]["extract"]
381
+ if take_head is not None and text_limit is not None:
382
+ page_content = page_content[:text_limit] if take_head else page_content[:-text_limit]
383
+ title_url = str(title).replace(' ', '_')
384
+ return Document(
385
+ page_content=page_content,
386
+ metadata={"source": f"https://en.wikipedia.org/wiki/{title_url}"},
387
+ )
388
+
389
+
390
+ def get_wiki_sources(first_para=True, text_limit=None):
391
+ """
392
+ Get specific named sources from wikipedia
393
+ :param first_para:
394
+ :param text_limit:
395
+ :return:
396
+ """
397
+ default_wiki_sources = ['Unix', 'Microsoft_Windows', 'Linux']
398
+ wiki_sources = list(os.getenv('WIKI_SOURCES', default_wiki_sources))
399
+ return [get_wiki_data(x, first_para, text_limit=text_limit) for x in wiki_sources]
400
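+ # Usage sketch (comment only): fetch the default Wikipedia pages as LangChain Documents:
+ #   sources = get_wiki_sources(first_para=True, text_limit=None)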
+
401
+
402
+ def get_github_docs(repo_owner, repo_name):
403
+ """
404
+ Access github from specific repo
405
+ :param repo_owner:
406
+ :param repo_name:
407
+ :return:
408
+ """
409
+ with tempfile.TemporaryDirectory() as d:
410
+ subprocess.check_call(
411
+ f"git clone --depth 1 https://github.com/{repo_owner}/{repo_name}.git .",
412
+ cwd=d,
413
+ shell=True,
414
+ )
415
+ git_sha = (
416
+ subprocess.check_output("git rev-parse HEAD", shell=True, cwd=d)
417
+ .decode("utf-8")
418
+ .strip()
419
+ )
420
+ repo_path = pathlib.Path(d)
421
+ markdown_files = list(repo_path.glob("*/*.md")) + list(
422
+ repo_path.glob("*/*.mdx")
423
+ )
424
+ for markdown_file in markdown_files:
425
+ with open(markdown_file, "r") as f:
426
+ relative_path = markdown_file.relative_to(repo_path)
427
+ github_url = f"https://github.com/{repo_owner}/{repo_name}/blob/{git_sha}/{relative_path}"
428
+ yield Document(page_content=f.read(), metadata={"source": github_url})
429
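+ # Usage sketch (comment only): this is a generator, so materialize it if needed; requires
+ # git and network access:
+ #   docs = list(get_github_docs("h2oai", "h2ogpt"))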
+
430
+
431
+ def get_dai_pickle(dest="."):
432
+ from huggingface_hub import hf_hub_download
433
+ # True for case when locally already logged in with correct token, so don't have to set key
434
+ token = os.getenv('HUGGINGFACE_API_TOKEN', True)
435
+ path_to_zip_file = hf_hub_download('h2oai/dai_docs', 'dai_docs.pickle', token=token, repo_type='dataset')
436
+ shutil.copy(path_to_zip_file, dest)
437
+
438
+
439
+ def get_dai_docs(from_hf=False, get_pickle=True):
440
+ """
441
+ Consume DAI documentation, or consume from public pickle
442
+ :param from_hf: get DAI docs from HF, then generate pickle for later use by LangChain
443
+ :param get_pickle: Avoid raw DAI docs, just get pickle directly from HF
444
+ :return:
445
+ """
446
+ import pickle
447
+
448
+ if get_pickle:
449
+ get_dai_pickle()
450
+
451
+ dai_store = 'dai_docs.pickle'
452
+ dst = "working_dir_docs"
453
+ if not os.path.isfile(dai_store):
454
+ from create_data import setup_dai_docs
455
+ dst = setup_dai_docs(dst=dst, from_hf=from_hf)
456
+
457
+ import glob
458
+ files = list(glob.glob(os.path.join(dst, '*rst'), recursive=True))
459
+
460
+ basedir = os.path.abspath(os.getcwd())
461
+ from create_data import rst_to_outputs
462
+ new_outputs = rst_to_outputs(files)
463
+ os.chdir(basedir)
464
+
465
+ pickle.dump(new_outputs, open(dai_store, 'wb'))
466
+ else:
467
+ new_outputs = pickle.load(open(dai_store, 'rb'))
468
+
469
+ sources = []
470
+ for line, file in new_outputs:
471
+ # gradio requires any linked file to be with app.py
472
+ sym_src = os.path.abspath(os.path.join(dst, file))
473
+ sym_dst = os.path.abspath(os.path.join(os.getcwd(), file))
474
+ if os.path.lexists(sym_dst):
475
+ os.remove(sym_dst)
476
+ os.symlink(sym_src, sym_dst)
477
+ itm = Document(page_content=line, metadata={"source": file})
478
+ # NOTE: yield has issues when going into db, loses metadata
479
+ # yield itm
480
+ sources.append(itm)
481
+ return sources
482
+
483
+
484
+ import distutils.spawn
485
+
486
+ have_tesseract = distutils.spawn.find_executable("tesseract")
487
+ have_libreoffice = distutils.spawn.find_executable("libreoffice")
488
+
489
+ import pkg_resources
490
+
491
+ try:
492
+ assert pkg_resources.get_distribution('arxiv') is not None
493
+ assert pkg_resources.get_distribution('pymupdf') is not None
494
+ have_arxiv = True
495
+ except (pkg_resources.DistributionNotFound, AssertionError):
496
+ have_arxiv = False
497
+
498
+ try:
499
+ assert pkg_resources.get_distribution('pymupdf') is not None
500
+ have_pymupdf = True
501
+ except (pkg_resources.DistributionNotFound, AssertionError):
502
+ have_pymupdf = False
503
+
504
+ image_types = ["png", "jpg", "jpeg"]
505
+ non_image_types = ["pdf", "txt", "csv", "toml", "py", "rst", "rtf",
506
+ "md", "html",
507
+ "enex", "eml", "epub", "odt", "pptx", "ppt",
508
+ "zip", "urls",
509
+ ]
510
+ # "msg", GPL3
511
+
512
+ if have_libreoffice:
513
+ non_image_types.extend(["docx", "doc"])
514
+
515
+ file_types = non_image_types + image_types
516
+
517
+
518
+ def add_meta(docs1, file):
519
+ file_extension = pathlib.Path(file).suffix
520
+ hashid = hash_file(file)
521
+ if not isinstance(docs1, list):
522
+ docs1 = [docs1]
523
+ [x.metadata.update(dict(input_type=file_extension, date=str(datetime.now()), hashid=hashid)) for x in docs1]
524
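+ # Usage sketch (comment only): stamps input_type, ingestion date, and file hash onto each
+ # Document's metadata in place ("notes.txt" is a hypothetical local file):
+ #   docs1 = TextLoader("notes.txt", encoding="utf8").load()
+ #   add_meta(docs1, "notes.txt")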
+
525
+
526
+ def file_to_doc(file, base_path=None, verbose=False, fail_any_exception=False,
527
+ chunk=True, chunk_size=512,
528
+ is_url=False, is_txt=False,
529
+ enable_captions=True,
530
+ captions_model=None,
531
+ enable_ocr=False, caption_loader=None,
532
+ headsize=50):
533
+ if file is None:
534
+ if fail_any_exception:
535
+ raise RuntimeError("Unexpected None file")
536
+ else:
537
+ return []
538
+ doc1 = [] # in case no support, or disabled support
539
+ if base_path is None and not is_txt and not is_url:
540
+ # then assume want to persist but don't care which path used
541
+ # can't be in base_path
542
+ dir_name = os.path.dirname(file)
543
+ base_name = os.path.basename(file)
544
+ # if from gradio, will have its own temp uuid too, but that's ok
545
+ base_name = sanitize_filename(base_name) + "_" + str(uuid.uuid4())[:10]
546
+ base_path = os.path.join(dir_name, base_name)
547
+ if is_url:
548
+ if file.lower().startswith('arxiv:'):
549
+ query = file.lower().split('arxiv:')
550
+ if len(query) == 2 and have_arxiv:
551
+ query = query[1]
552
+ docs1 = ArxivLoader(query=query, load_max_docs=20, load_all_available_meta=True).load()
553
+ # ensure string, sometimes None
554
+ [[x.metadata.update({k: str(v)}) for k, v in x.metadata.items()] for x in docs1]
555
+ query_url = f"https://arxiv.org/abs/{query}"
556
+ [x.metadata.update(
557
+ dict(source=x.metadata.get('entry_id', query_url), query=query_url,
558
+ input_type='arxiv', head=x.metadata.get('Title', ''), date=str(datetime.now()))) for x in
559
+ docs1]
560
+ else:
561
+ docs1 = []
562
+ else:
563
+ docs1 = UnstructuredURLLoader(urls=[file]).load()
564
+ [x.metadata.update(dict(input_type='url', date=str(datetime.now()))) for x in docs1]
565
+ doc1 = chunk_sources(docs1, chunk=chunk, chunk_size=chunk_size)
566
+ elif is_txt:
567
+ base_path = "user_paste"
568
+ source_file = os.path.join(base_path, "_%s" % str(uuid.uuid4())[:10])
569
+ makedirs(os.path.dirname(source_file), exist_ok=True)
570
+ with open(source_file, "wt") as f:
571
+ f.write(file)
572
+ metadata = dict(source=source_file, date=str(datetime.now()), input_type='pasted txt')
573
+ doc1 = Document(page_content=file, metadata=metadata)
574
+ elif file.lower().endswith('.html') or file.lower().endswith('.mhtml'):
575
+ docs1 = UnstructuredHTMLLoader(file_path=file).load()
576
+ add_meta(docs1, file)
577
+ doc1 = chunk_sources(docs1, chunk=chunk, chunk_size=chunk_size)
578
+ elif (file.lower().endswith('.docx') or file.lower().endswith('.doc')) and have_libreoffice:
579
+ docs1 = UnstructuredWordDocumentLoader(file_path=file).load()
580
+ add_meta(docs1, file)
581
+ doc1 = chunk_sources(docs1, chunk=chunk, chunk_size=chunk_size)
582
+ elif file.lower().endswith('.odt'):
583
+ docs1 = UnstructuredODTLoader(file_path=file).load()
584
+ add_meta(docs1, file)
585
+ doc1 = chunk_sources(docs1, chunk=chunk, chunk_size=chunk_size)
586
+ elif file.lower().endswith('pptx') or file.lower().endswith('ppt'):
587
+ docs1 = UnstructuredPowerPointLoader(file_path=file).load()
588
+ add_meta(docs1, file)
589
+ doc1 = chunk_sources(docs1, chunk=chunk, chunk_size=chunk_size)
590
+ elif file.lower().endswith('.txt'):
591
+ # use UnstructuredFileLoader ?
592
+ docs1 = TextLoader(file, encoding="utf8", autodetect_encoding=True).load()
593
+ # makes just one, but big one
594
+ doc1 = chunk_sources(docs1, chunk=chunk, chunk_size=chunk_size)
595
+ add_meta(doc1, file)
596
+ elif file.lower().endswith('.rtf'):
597
+ docs1 = UnstructuredRTFLoader(file).load()
598
+ add_meta(docs1, file)
599
+ doc1 = chunk_sources(docs1, chunk=chunk, chunk_size=chunk_size)
600
+ elif file.lower().endswith('.md'):
601
+ docs1 = UnstructuredMarkdownLoader(file).load()
602
+ add_meta(docs1, file)
603
+ doc1 = chunk_sources(docs1, chunk=chunk, chunk_size=chunk_size)
604
+ elif file.lower().endswith('.enex'):
605
+ docs1 = EverNoteLoader(file).load()
606
+ add_meta(docs1, file)
607
+ doc1 = chunk_sources(docs1, chunk=chunk, chunk_size=chunk_size)
608
+ elif file.lower().endswith('.epub'):
609
+ docs1 = UnstructuredEPubLoader(file).load()
610
+ add_meta(docs1, file)
611
+ doc1 = chunk_sources(docs1, chunk=chunk, chunk_size=chunk_size)
612
+ elif file.lower().endswith('.jpeg') or file.lower().endswith('.jpg') or file.lower().endswith('.png'):
613
+ docs1 = []
614
+ if have_tesseract and enable_ocr:
615
+ # OCR, somewhat works, but not great
616
+ docs1.extend(UnstructuredImageLoader(file).load())
617
+ add_meta(docs1, file)
618
+ if enable_captions:
619
+ # BLIP
620
+ if caption_loader is not None and not isinstance(caption_loader, (str, bool)):
621
+ # assumes didn't fork into this process with joblib, else can deadlock
622
+ caption_loader.set_image_paths([file])
623
+ docs1c = caption_loader.load()
624
+ add_meta(docs1c, file)
625
+ [x.metadata.update(dict(head=x.page_content[:headsize].strip())) for x in docs1c]
626
+ docs1.extend(docs1c)
627
+ else:
628
+ from image_captions import H2OImageCaptionLoader
629
+ caption_loader = H2OImageCaptionLoader(caption_gpu=caption_loader == 'gpu',
630
+ blip_model=captions_model,
631
+ blip_processor=captions_model)
632
+ caption_loader.set_image_paths([file])
633
+ docs1c = caption_loader.load()
634
+ add_meta(docs1c, file)
635
+ [x.metadata.update(dict(head=x.page_content[:headsize].strip())) for x in docs1c]
636
+ docs1.extend(docs1c)
637
+ for doci in docs1:
638
+ doci.metadata['source'] = doci.metadata['image_path']
639
+ doci.metadata['hash'] = hash_file(doci.metadata['source'])
640
+ if docs1:
641
+ doc1 = chunk_sources(docs1, chunk=chunk, chunk_size=chunk_size)
642
+ elif file.lower().endswith('.msg'):
643
+ raise RuntimeError("Not supported, GPL3 license")
644
+ # docs1 = OutlookMessageLoader(file).load()
645
+ # docs1[0].metadata['source'] = file
646
+ elif file.lower().endswith('.eml'):
647
+ try:
648
+ docs1 = UnstructuredEmailLoader(file).load()
649
+ add_meta(docs1, file)
650
+ doc1 = chunk_sources(docs1, chunk=chunk, chunk_size=chunk_size)
651
+ except ValueError as e:
652
+ if 'text/html content not found in email' in str(e):
653
+ # e.g. plain/text dict key exists, but not
654
+ # doc1 = TextLoader(file, encoding="utf8").load()
655
+ docs1 = UnstructuredEmailLoader(file, content_source="text/plain").load()
656
+ add_meta(docs1, file)
657
+ doc1 = chunk_sources(docs1, chunk=chunk, chunk_size=chunk_size)
658
+ else:
659
+ raise
660
+ # elif file.lower().endswith('.gcsdir'):
661
+ # doc1 = GCSDirectoryLoader(project_name, bucket, prefix).load()
662
+ # elif file.lower().endswith('.gcsfile'):
663
+ # doc1 = GCSFileLoader(project_name, bucket, blob).load()
664
+ elif file.lower().endswith('.rst'):
665
+ with open(file, "r") as f:
666
+ doc1 = Document(page_content=f.read(), metadata={"source": file})
667
+ add_meta(doc1, file)
668
+ elif file.lower().endswith('.pdf'):
669
+ env_gpt4all_file = ".env_gpt4all"
670
+ from dotenv import dotenv_values
671
+ env_kwargs = dotenv_values(env_gpt4all_file)
672
+ pdf_class_name = env_kwargs.get('PDF_CLASS_NAME', 'PyMuPDFParser')
673
+ if have_pymupdf and pdf_class_name == 'PyMuPDFParser':
674
+ # GPL, only use if installed
675
+ from langchain.document_loaders import PyMuPDFLoader
676
+ # load() still chunks by pages, but every page has title at start to help
677
+ doc1 = PyMuPDFLoader(file).load()
678
+ else:
679
+ # open-source fallback
680
+ # load() still chunks by pages, but every page has title at start to help
681
+ doc1 = PyPDFLoader(file).load()
682
+ # Some PDFs return nothing or junk from PDFMinerLoader
683
+ add_meta(doc1, file)
684
+ elif file.lower().endswith('.csv'):
685
+ doc1 = CSVLoader(file).load()
686
+ add_meta(doc1, file)
687
+ elif file.lower().endswith('.py'):
688
+ doc1 = PythonLoader(file).load()
689
+ add_meta(doc1, file)
690
+ elif file.lower().endswith('.toml'):
691
+ doc1 = TomlLoader(file).load()
692
+ add_meta(doc1, file)
693
+ elif file.lower().endswith('.urls'):
694
+ with open(file, "r") as f:
695
+ docs1 = UnstructuredURLLoader(urls=f.readlines()).load()
696
+ add_meta(docs1, file)
697
+ doc1 = chunk_sources(docs1, chunk=chunk, chunk_size=chunk_size)
698
+ elif file.lower().endswith('.zip'):
699
+ with zipfile.ZipFile(file, 'r') as zip_ref:
700
+ # don't put into temporary path, since want to keep references to docs inside zip
701
+ # so just extract into base_path, which sits next to the original zip
702
+ zip_ref.extractall(base_path)
703
+ # recurse
704
+ doc1 = path_to_docs(base_path, verbose=verbose, fail_any_exception=fail_any_exception)
705
+ else:
706
+ raise RuntimeError("No file handler for %s" % os.path.basename(file))
707
+
708
+ # allow doc1 to be list or not. If not list, did not chunk yet, so chunk now
709
+ # if list of length one, don't trust and chunk it
710
+ if not isinstance(doc1, list):
711
+ if chunk:
712
+ docs = chunk_sources([doc1], chunk=chunk, chunk_size=chunk_size)
713
+ else:
714
+ docs = [doc1]
715
+ elif isinstance(doc1, list) and len(doc1) == 1:
716
+ if chunk:
717
+ docs = chunk_sources(doc1, chunk=chunk, chunk_size=chunk_size)
718
+ else:
719
+ docs = doc1
720
+ else:
721
+ docs = doc1
722
+
723
+ assert isinstance(docs, list)
724
+ return docs
725
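+ # Usage sketch (comment only); "manual.pdf" is a hypothetical local file, and the return
+ # value is a list of chunked Documents:
+ #   chunks = file_to_doc("manual.pdf", chunk=True, chunk_size=512)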
+
726
+
727
+ def path_to_doc1(file, verbose=False, fail_any_exception=False, return_file=True,
728
+ chunk=True, chunk_size=512,
729
+ is_url=False, is_txt=False,
730
+ enable_captions=True,
731
+ captions_model=None,
732
+ enable_ocr=False, caption_loader=None):
733
+ if verbose:
734
+ if is_url:
735
+ print("Ingesting URL: %s" % file, flush=True)
736
+ elif is_txt:
737
+ print("Ingesting Text: %s" % file, flush=True)
738
+ else:
739
+ print("Ingesting file: %s" % file, flush=True)
740
+ res = None
741
+ try:
742
+ # don't pass base_path=path, would infinitely recurse
743
+ res = file_to_doc(file, base_path=None, verbose=verbose, fail_any_exception=fail_any_exception,
744
+ chunk=chunk, chunk_size=chunk_size,
745
+ is_url=is_url, is_txt=is_txt,
746
+ enable_captions=enable_captions,
747
+ captions_model=captions_model,
748
+ enable_ocr=enable_ocr,
749
+ caption_loader=caption_loader)
750
+ except BaseException as e:
751
+ print("Failed to ingest %s due to %s" % (file, traceback.format_exc()))
752
+ if fail_any_exception:
753
+ raise
754
+ else:
755
+ exception_doc = Document(
756
+ page_content='',
757
+ metadata={"source": file, "exception": str(e), "traceback": traceback.format_exc()})
758
+ res = [exception_doc]
759
+ if return_file:
760
+ base_tmp = "temp_path_to_doc1"
761
+ if not os.path.isdir(base_tmp):
762
+ os.makedirs(base_tmp, exist_ok=True)
763
+ filename = os.path.join(base_tmp, str(uuid.uuid4()) + ".tmp.pickle")
764
+ with open(filename, 'wb') as f:
765
+ pickle.dump(res, f)
766
+ return filename
767
+ return res
768
+
769
+
770
+ def path_to_docs(path_or_paths, verbose=False, fail_any_exception=False, n_jobs=-1,
771
+ chunk=True, chunk_size=512,
772
+ url=None, text=None,
773
+ enable_captions=True,
774
+ captions_model=None,
775
+ caption_loader=None,
776
+ enable_ocr=False,
777
+ existing_files=[],
778
+ existing_hash_ids={},
779
+ ):
780
+ globs_image_types = []
781
+ globs_non_image_types = []
782
+ if not path_or_paths and not url and not text:
783
+ return []
784
+ elif url:
785
+ globs_non_image_types = [url]
786
+ elif text:
787
+ globs_non_image_types = [text]
788
+ elif isinstance(path_or_paths, str):
789
+ # single path, only consume allowed files
790
+ path = path_or_paths
791
+ # Below globs should match patterns in file_to_doc()
792
+ [globs_image_types.extend(glob.glob(os.path.join(path, "./**/*.%s" % ftype), recursive=True))
793
+ for ftype in image_types]
794
+ [globs_non_image_types.extend(glob.glob(os.path.join(path, "./**/*.%s" % ftype), recursive=True))
795
+ for ftype in non_image_types]
796
+ else:
797
+ # list/tuple of files (consume what can, and exception those that selected but cannot consume so user knows)
798
+ assert isinstance(path_or_paths, (list, tuple)), "Wrong type for path_or_paths: %s" % type(path_or_paths)
799
+ # reform out of allowed types
800
+ globs_image_types.extend(flatten_list([[x for x in path_or_paths if x.endswith(y)] for y in image_types]))
801
+ # could do below:
802
+ # globs_non_image_types = flatten_list([[x for x in path_or_paths if x.endswith(y)] for y in non_image_types])
803
+ # But instead, allow fail so can collect unsupported too
804
+ set_globs_image_types = set(globs_image_types)
805
+ globs_non_image_types.extend([x for x in path_or_paths if x not in set_globs_image_types])
806
+
807
+ # filter out any files to skip (e.g. if already processed them)
808
+ # this is easy, but too aggressive in case a file changed, so parent probably passed existing_files=[]
809
+ assert not existing_files, "DEV: assume not using this approach"
810
+ if existing_files:
811
+ set_skip_files = set(existing_files)
812
+ globs_image_types = [x for x in globs_image_types if x not in set_skip_files]
813
+ globs_non_image_types = [x for x in globs_non_image_types if x not in set_skip_files]
814
+ if existing_hash_ids:
815
+ # assume consistent with add_meta() use of hash_file(file)
816
+ # also assume consistent with get_existing_hash_ids for dict creation
817
+ # assume hashable values
818
+ existing_hash_ids_set = set(existing_hash_ids.items())
819
+ hash_ids_all_image = set({x: hash_file(x) for x in globs_image_types}.items())
820
+ hash_ids_all_non_image = set({x: hash_file(x) for x in globs_non_image_types}.items())
821
+ # don't use symmetric diff; if a file is gone, ignore it rather than removing anything from db
822
+ # just consider existing files (key) having new hash or not (value)
823
+ new_files_image = set(dict(hash_ids_all_image - existing_hash_ids_set).keys())
824
+ new_files_non_image = set(dict(hash_ids_all_non_image - existing_hash_ids_set).keys())
825
+ globs_image_types = [x for x in globs_image_types if x in new_files_image]
826
+ globs_non_image_types = [x for x in globs_non_image_types if x in new_files_non_image]
827
+
828
+ # could use generator, but messes up metadata handling in recursive case
829
+ if caption_loader and not isinstance(caption_loader, (bool, str)) and \
830
+ caption_loader.device != 'cpu' or \
831
+ get_device() == 'cuda':
832
+ # to avoid deadlocks, presume was preloaded and so can't fork due to cuda context
833
+ n_jobs_image = 1
834
+ else:
835
+ n_jobs_image = n_jobs
836
+
837
+ return_file = True # local choice
838
+ is_url = url is not None
839
+ is_txt = text is not None
840
+ kwargs = dict(verbose=verbose, fail_any_exception=fail_any_exception,
841
+ return_file=return_file,
842
+ chunk=chunk, chunk_size=chunk_size,
843
+ is_url=is_url,
844
+ is_txt=is_txt,
845
+ enable_captions=enable_captions,
846
+ captions_model=captions_model,
847
+ caption_loader=caption_loader,
848
+ enable_ocr=enable_ocr,
849
+ )
850
+
851
+ if n_jobs != 1 and len(globs_non_image_types) > 1:
852
+ # avoid nesting, e.g. upload 1 zip and then inside many files
853
+ # harder to handle if upload many zips with many files, inner parallel one will be disabled by joblib
854
+ documents = ProgressParallel(n_jobs=n_jobs, verbose=10 if verbose else 0, backend='multiprocessing')(
855
+ delayed(path_to_doc1)(file, **kwargs) for file in globs_non_image_types
856
+ )
857
+ else:
858
+ documents = [path_to_doc1(file, **kwargs) for file in tqdm(globs_non_image_types)]
859
+
860
+ # do images separately since can't fork after cuda in parent, so can't be parallel
861
+ if n_jobs_image != 1 and len(globs_image_types) > 1:
862
+ # avoid nesting, e.g. upload 1 zip and then inside many files
863
+ # harder to handle if upload many zips with many files, inner parallel one will be disabled by joblib
864
+ image_documents = ProgressParallel(n_jobs=n_jobs, verbose=10 if verbose else 0, backend='multiprocessing')(
865
+ delayed(path_to_doc1)(file, **kwargs) for file in globs_image_types
866
+ )
867
+ else:
868
+ image_documents = [path_to_doc1(file, **kwargs) for file in tqdm(globs_image_types)]
869
+
870
+ # add image docs in
871
+ documents += image_documents
872
+
873
+ if return_file:
874
+ # then documents really are files
875
+ files = documents.copy()
876
+ documents = []
877
+ for fil in files:
878
+ with open(fil, 'rb') as f:
879
+ documents.extend(pickle.load(f))
880
+ # remove temp pickle
881
+ os.remove(fil)
882
+ else:
883
+ documents = reduce(concat, documents)
884
+ return documents
885
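+ # Usage sketch (comment only); "user_path" is a hypothetical directory that is globbed
+ # recursively for the supported file types above:
+ #   docs = path_to_docs("user_path", n_jobs=4, chunk=True, chunk_size=512)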
+
886
+
887
+ def prep_langchain(persist_directory,
888
+ load_db_if_exists,
889
+ db_type, use_openai_embedding, langchain_mode, user_path,
890
+ hf_embedding_model, n_jobs=-1, kwargs_make_db={}):
891
+ """
892
+ do prep first time, involving downloads
893
+ # FIXME: Add github caching then add here
894
+ :return:
895
+ """
896
+ assert langchain_mode not in ['MyData'], "Should not prep scratch data"
897
+
898
+ db_dir_exists = os.path.isdir(persist_directory)
899
+
900
+ if db_dir_exists and user_path is None:
901
+ print("Prep: persist_directory=%s exists, using" % persist_directory, flush=True)
902
+ db = get_existing_db(None, persist_directory, load_db_if_exists, db_type, use_openai_embedding, langchain_mode,
903
+ hf_embedding_model)
904
+ else:
905
+ if db_dir_exists and user_path is not None:
906
+ print("Prep: persist_directory=%s exists, user_path=%s passed, adding any changed or new documents" % (
907
+ persist_directory, user_path), flush=True)
908
+ elif not db_dir_exists:
909
+ print("Prep: persist_directory=%s does not exist, regenerating" % persist_directory, flush=True)
910
+ db = None
911
+ if langchain_mode in ['All', 'DriverlessAI docs']:
912
+ # FIXME: Could also just use dai_docs.pickle directly and upload that
913
+ get_dai_docs(from_hf=True)
914
+
915
+ if langchain_mode in ['All', 'wiki']:
916
+ get_wiki_sources(first_para=kwargs_make_db['first_para'], text_limit=kwargs_make_db['text_limit'])
917
+
918
+ langchain_kwargs = kwargs_make_db.copy()
919
+ langchain_kwargs.update(locals())
920
+ db, num_new_sources, new_sources_metadata = make_db(**langchain_kwargs)
921
+
922
+ return db
923
+
924
+
925
+ import posthog
926
+
927
+ posthog.disabled = True
928
+
929
+
930
+ class FakeConsumer(object):
931
+ def __init__(self, *args, **kwargs):
932
+ pass
933
+
934
+ def run(self):
935
+ pass
936
+
937
+ def pause(self):
938
+ pass
939
+
940
+ def upload(self):
941
+ pass
942
+
943
+ def next(self):
944
+ pass
945
+
946
+ def request(self, batch):
947
+ pass
948
+
949
+
950
+ posthog.Consumer = FakeConsumer
951
+
952
+
953
+ def check_update_chroma_embedding(db, use_openai_embedding, hf_embedding_model, langchain_mode):
954
+ changed_db = False
955
+ if load_embed(db) != (use_openai_embedding, hf_embedding_model):
956
+ print("Detected new embedding, updating db: %s" % langchain_mode, flush=True)
957
+ # handle embedding changes
958
+ db_get = db.get()
959
+ sources = [Document(page_content=result[0], metadata=result[1] or {})
960
+ for result in zip(db_get['documents'], db_get['metadatas'])]
961
+ # delete index, has to be redone
962
+ persist_directory = db._persist_directory
963
+ shutil.move(persist_directory, persist_directory + "_" + str(uuid.uuid4()) + ".bak")
964
+ db_type = 'chroma'
965
+ load_db_if_exists = False
966
+ db = get_db(sources, use_openai_embedding=use_openai_embedding, db_type=db_type,
967
+ persist_directory=persist_directory, load_db_if_exists=load_db_if_exists,
968
+ langchain_mode=langchain_mode,
969
+ collection_name=None,
970
+ hf_embedding_model=hf_embedding_model)
971
+ if False:
972
+ # below doesn't work if db already in memory, so have to switch to new db as above
973
+ # upsert does new embedding, but if index already in memory, complains about size mismatch etc.
974
+ client_collection = db._client.get_collection(name=db._collection.name,
975
+ embedding_function=db._collection._embedding_function)
976
+ client_collection.upsert(ids=db_get['ids'], metadatas=db_get['metadatas'], documents=db_get['documents'])
977
+ changed_db = True
978
+ print("Done updating db for new embedding: %s" % langchain_mode, flush=True)
979
+
980
+ return db, changed_db
981
+
982
+
983
+ def get_existing_db(db, persist_directory, load_db_if_exists, db_type, use_openai_embedding, langchain_mode,
984
+ hf_embedding_model, verbose=False, check_embedding=True):
985
+ if load_db_if_exists and db_type == 'chroma' and os.path.isdir(persist_directory) and os.path.isdir(
986
+ os.path.join(persist_directory, 'index')):
987
+ if db is None:
988
+ if verbose:
989
+ print("DO Loading db: %s" % langchain_mode, flush=True)
990
+ embedding = get_embedding(use_openai_embedding, hf_embedding_model=hf_embedding_model)
991
+ from chromadb.config import Settings
992
+ client_settings = Settings(anonymized_telemetry=False,
993
+ chroma_db_impl="duckdb+parquet",
994
+ persist_directory=persist_directory)
995
+ db = Chroma(persist_directory=persist_directory, embedding_function=embedding,
996
+ collection_name=langchain_mode.replace(' ', '_'),
997
+ client_settings=client_settings)
998
+ if verbose:
999
+ print("DONE Loading db: %s" % langchain_mode, flush=True)
1000
+ else:
1001
+ if verbose:
1002
+ print("USING already-loaded db: %s" % langchain_mode, flush=True)
1003
+ if check_embedding:
1004
+ db_trial, changed_db = check_update_chroma_embedding(db, use_openai_embedding, hf_embedding_model,
1005
+ langchain_mode)
1006
+ if changed_db:
1007
+ db = db_trial
1008
+ # only call persist if really changed db, else takes too long for large db
1009
+ db.persist()
1010
+ clear_embedding(db)
1011
+ save_embed(db, use_openai_embedding, hf_embedding_model)
1012
+ return db
1013
+ return None
1014
+
1015
+
1016
+ def clear_embedding(db):
1017
+ # don't keep on GPU, it wastes memory; push back onto CPU and only put back on GPU once we embed again
1018
+ db._embedding_function.client.cpu()
1019
+ clear_torch_cache()
1020
+
1021
+
1022
+ def make_db(**langchain_kwargs):
1023
+ func_names = list(inspect.signature(_make_db).parameters)
1024
+ missing_kwargs = [x for x in func_names if x not in langchain_kwargs]
1025
+ defaults_db = {k: v.default for k, v in dict(inspect.signature(run_qa_db).parameters).items()}
1026
+ for k in missing_kwargs:
1027
+ if k in defaults_db:
1028
+ langchain_kwargs[k] = defaults_db[k]
1029
+ # final check for missing
1030
+ missing_kwargs = [x for x in func_names if x not in langchain_kwargs]
1031
+ assert not missing_kwargs, "Missing kwargs: %s" % missing_kwargs
1032
+ # only keep actual used
1033
+ langchain_kwargs = {k: v for k, v in langchain_kwargs.items() if k in func_names}
1034
+ return _make_db(**langchain_kwargs)
1035
+
1036
+
1037
+ def save_embed(db, use_openai_embedding, hf_embedding_model):
1038
+ embed_info_file = os.path.join(db._persist_directory, 'embed_info')
1039
+ with open(embed_info_file, 'wb') as f:
1040
+ pickle.dump((use_openai_embedding, hf_embedding_model), f)
1041
+ return use_openai_embedding, hf_embedding_model
1042
+
1043
+
1044
+ def load_embed(db):
1045
+ embed_info_file = os.path.join(db._persist_directory, 'embed_info')
1046
+ if os.path.isfile(embed_info_file):
1047
+ with open(embed_info_file, 'rb') as f:
1048
+ use_openai_embedding, hf_embedding_model = pickle.load(f)
1049
+ else:
1050
+ # migration, assume defaults
1051
+ use_openai_embedding, hf_embedding_model = False, "sentence-transformers/all-MiniLM-L6-v2"
1052
+ return use_openai_embedding, hf_embedding_model
1053
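+ # Usage sketch (comment only): persist which embedding built the db, then read it back to
+ # decide whether re-embedding is needed (see check_update_chroma_embedding above):
+ #   save_embed(db, False, "sentence-transformers/all-MiniLM-L6-v2")
+ #   use_openai_embedding, hf_embedding_model = load_embed(db)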
+
1054
+
1055
+ def get_persist_directory(langchain_mode):
1056
+ return 'db_dir_%s' % langchain_mode # single place, no special names for each case
1057
+
1058
+
1059
+ def _make_db(use_openai_embedding=False,
1060
+ hf_embedding_model="sentence-transformers/all-MiniLM-L6-v2",
1061
+ first_para=False, text_limit=None,
1062
+ chunk=True, chunk_size=512,
1063
+ langchain_mode=None,
1064
+ user_path=None,
1065
+ db_type='faiss',
1066
+ load_db_if_exists=True,
1067
+ db=None,
1068
+ n_jobs=-1,
1069
+ verbose=False):
1070
+ persist_directory = get_persist_directory(langchain_mode)
1071
+ # see if can get persistent chroma db
1072
+ db_trial = get_existing_db(db, persist_directory, load_db_if_exists, db_type, use_openai_embedding, langchain_mode,
1073
+ hf_embedding_model, verbose=verbose)
1074
+ if db_trial is not None:
1075
+ db = db_trial
1076
+
1077
+ sources = []
1078
+ if not db and langchain_mode not in ['MyData'] or \
1079
+ user_path is not None and \
1080
+ langchain_mode in ['UserData']:
1081
+ # Should not make MyData db this way; MyData is scratch data and only comes from UI uploads
1082
+ assert langchain_mode not in ['MyData'], "Should not make MyData db this way"
1083
+ if verbose:
1084
+ if langchain_mode in ['UserData']:
1085
+ if user_path is not None:
1086
+ print("Checking if changed or new sources in %s, and generating sources them" % user_path,
1087
+ flush=True)
1088
+ elif db is None:
1089
+ print("user_path not passed and no db, no sources", flush=True)
1090
+ else:
1091
+ print("user_path not passed, using only existing db, no new sources", flush=True)
1092
+ else:
1093
+ print("Generating %s sources" % langchain_mode, flush=True)
1094
+ if langchain_mode in ['wiki_full', 'All', "'All'"]:
1095
+ from read_wiki_full import get_all_documents
1096
+ small_test = None
1097
+ print("Generating new wiki", flush=True)
1098
+ sources1 = get_all_documents(small_test=small_test, n_jobs=os.cpu_count() // 2)
1099
+ print("Got new wiki", flush=True)
1100
+ if chunk:
1101
+ sources1 = chunk_sources(sources1, chunk=chunk, chunk_size=chunk_size)
1102
+ print("Chunked new wiki", flush=True)
1103
+ sources.extend(sources1)
1104
+ if langchain_mode in ['wiki', 'All', "'All'"]:
1105
+ sources1 = get_wiki_sources(first_para=first_para, text_limit=text_limit)
1106
+ if chunk:
1107
+ sources1 = chunk_sources(sources1, chunk=chunk, chunk_size=chunk_size)
1108
+ sources.extend(sources1)
1109
+ if langchain_mode in ['github h2oGPT', 'All', "'All'"]:
1110
+ # sources = get_github_docs("dagster-io", "dagster")
1111
+ sources1 = get_github_docs("h2oai", "h2ogpt")
1112
+ # FIXME: always chunk for now
1113
+ sources1 = chunk_sources(sources1, chunk=chunk, chunk_size=chunk_size)
1114
+ sources.extend(sources1)
1115
+ if langchain_mode in ['DriverlessAI docs', 'All', "'All'"]:
1116
+ sources1 = get_dai_docs(from_hf=True)
1117
+ if chunk and False: # FIXME: DAI docs are already chunked well, should only chunk more if over limit
1118
+ sources1 = chunk_sources(sources1, chunk=chunk, chunk_size=chunk_size)
1119
+ sources.extend(sources1)
1120
+ if langchain_mode in ['All', 'UserData']:
1121
+ if user_path:
1122
+ if db is not None:
1123
+ # NOTE: Ignore file names for now, only go by hash ids
1124
+ # existing_files = get_existing_files(db)
1125
+ existing_files = []
1126
+ existing_hash_ids = get_existing_hash_ids(db)
1127
+ else:
1128
+ # pretend no existing files so won't filter
1129
+ existing_files = []
1130
+ existing_hash_ids = []
1131
+ # chunk internally for speed over multiple docs
1132
+ # FIXME: if files originally had hashid=None and the embedding model is switched,
1133
+ # then we re-embed, hit here, reload so hashes exist, and then re-embed again.
1134
+ sources1 = path_to_docs(user_path, n_jobs=n_jobs, chunk=chunk, chunk_size=chunk_size,
1135
+ existing_files=existing_files, existing_hash_ids=existing_hash_ids)
1136
+ new_metadata_sources = set([x.metadata['source'] for x in sources1])
1137
+ if new_metadata_sources:
1138
+ print("Loaded %s new files as sources to add to UserData" % len(new_metadata_sources), flush=True)
1139
+ if verbose:
1140
+ print("Files added: %s" % '\n'.join(new_metadata_sources), flush=True)
1141
+ sources.extend(sources1)
1142
+ print("Loaded %s sources for potentially adding to UserData" % len(sources), flush=True)
1143
+ else:
1144
+ print("Chose UserData but user_path is empty/None", flush=True)
1145
+ if False and langchain_mode in ['urls', 'All', "'All'"]:
1146
+ # from langchain.document_loaders import UnstructuredURLLoader
1147
+ # loader = UnstructuredURLLoader(urls=urls)
1148
+ urls = ["https://www.birdsongsf.com/who-we-are/"]
1149
+ from langchain.document_loaders import PlaywrightURLLoader
1150
+ loader = PlaywrightURLLoader(urls=urls, remove_selectors=["header", "footer"])
1151
+ sources1 = loader.load()
1152
+ sources.extend(sources1)
1153
+ if not sources:
1154
+ if verbose:
1155
+ if db is not None:
1156
+ print("langchain_mode %s has no new sources, nothing to add to db" % langchain_mode, flush=True)
1157
+ else:
1158
+ print("langchain_mode %s has no sources, not making new db" % langchain_mode, flush=True)
1159
+ return db, 0, []
1160
+ if verbose:
1161
+ if db is not None:
1162
+ print("Generating db", flush=True)
1163
+ else:
1164
+ print("Adding to db", flush=True)
1165
+ if not db:
1166
+ if sources:
1167
+ db = get_db(sources, use_openai_embedding=use_openai_embedding, db_type=db_type,
1168
+ persist_directory=persist_directory, langchain_mode=langchain_mode,
1169
+ hf_embedding_model=hf_embedding_model)
1170
+ if verbose:
1171
+ print("Generated db", flush=True)
1172
+ else:
1173
+ print("Did not generate db since no sources", flush=True)
1174
+ new_sources_metadata = [x.metadata for x in sources]
1175
+ elif user_path is not None and langchain_mode in ['UserData']:
1176
+ print("Existing db, potentially adding %s sources from user_path=%s" % (len(sources), user_path), flush=True)
1177
+ db, num_new_sources, new_sources_metadata = add_to_db(db, sources, db_type=db_type,
1178
+ use_openai_embedding=use_openai_embedding,
1179
+ hf_embedding_model=hf_embedding_model)
1180
+ print("Existing db, added %s new sources from user_path=%s" % (num_new_sources, user_path), flush=True)
1181
+ else:
1182
+ new_sources_metadata = [x.metadata for x in sources]
1183
+
1184
+ return db, len(new_sources_metadata), new_sources_metadata
1185
+
1186
+
1187
+ def get_existing_files(db):
1188
+ collection = db.get()
1189
+ metadata_sources = set([x['source'] for x in collection['metadatas']])
1190
+ return metadata_sources
1191
+
1192
+
1193
+ def get_existing_hash_ids(db):
1194
+ collection = db.get()
1195
+ # assume consistency, that any prior hashed source was single hashed file at the time among all source chunks
1196
+ metadata_hash_ids = {x['source']: x.get('hashid') for x in collection['metadatas']}
1197
+ return metadata_hash_ids
1198
+
1199
+
1200
+ source_prefix = "Sources [Score | Link]:"
1201
+ source_postfix = "End Sources<p>"
1202
+
1203
+
1204
+ def run_qa_db(**kwargs):
1205
+ func_names = list(inspect.signature(_run_qa_db).parameters)
1206
+ # hard-coded defaults
1207
+ kwargs['answer_with_sources'] = True
1208
+ kwargs['sanitize_bot_response'] = True
1209
+ kwargs['show_rank'] = False
1210
+ missing_kwargs = [x for x in func_names if x not in kwargs]
1211
+ assert not missing_kwargs, "Missing kwargs: %s" % missing_kwargs
1212
+ # only keep actual used
1213
+ kwargs = {k: v for k, v in kwargs.items() if k in func_names}
1214
+ try:
1215
+ return _run_qa_db(**kwargs)
1216
+ finally:
1217
+ clear_torch_cache()
1218
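+ # Usage sketch (comment only): run_qa_db() is a generator, and the assert above requires the
+ # caller to pass every remaining _run_qa_db parameter explicitly (as generate.py does);
+ # 'all_qa_kwargs' below is a hypothetical dict mirroring the _run_qa_db defaults:
+ #   for partial_answer, sources_html in run_qa_db(query="What is Unix?", db=db, **all_qa_kwargs):
+ #       print(partial_answer)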
+
1219
+
1220
+ def _run_qa_db(query=None,
1221
+ use_openai_model=False, use_openai_embedding=False,
1222
+ first_para=False, text_limit=None, top_k_docs=4, chunk=True, chunk_size=512,
1223
+ user_path=None,
1224
+ detect_user_path_changes_every_query=False,
1225
+ db_type='faiss',
1226
+ model_name=None, model=None, tokenizer=None,
1227
+ hf_embedding_model="sentence-transformers/all-MiniLM-L6-v2",
1228
+ stream_output=False,
1229
+ prompter=None,
1230
+ prompt_type=None,
1231
+ prompt_dict=None,
1232
+ answer_with_sources=True,
1233
+ cut_distanct=1.1,
1234
+ sanitize_bot_response=True,
1235
+ show_rank=False,
1236
+ load_db_if_exists=False,
1237
+ db=None,
1238
+ max_new_tokens=256,
1239
+ temperature=0.1,
1240
+ repetition_penalty=1.0,
1241
+ top_k=40,
1242
+ top_p=0.7,
1243
+ langchain_mode=None,
1244
+ document_choice=[DocumentChoices.All_Relevant.name],
1245
+ n_jobs=-1,
1246
+ verbose=False,
1247
+ cli=False):
1248
+ """
1249
+
1250
+ :param query:
1251
+ :param use_openai_model:
1252
+ :param use_openai_embedding:
1253
+ :param first_para:
1254
+ :param text_limit:
1255
+ :param top_k_docs:
1256
+ :param chunk:
1257
+ :param chunk_size:
1258
+ :param user_path: user path to glob recursively from
1259
+ :param db_type: 'faiss' for in-memory db or 'chroma' or 'weaviate' for persistent db
1260
+ :param model_name: model name, used to switch behaviors
1261
+ :param model: pre-initialized model, else will make new one
1262
+ :param tokenizer: pre-initialized tokenizer, else will make new one. Required not None if model is not None
1263
+ :param answer_with_sources:
1264
+ :return:
1265
+ """
1266
+ assert query is not None
1267
+ assert prompter is not None or prompt_type is not None or model is None # if model is None, then will generate
1268
+ if prompter is not None:
1269
+ prompt_type = prompter.prompt_type
1270
+ prompt_dict = prompter.prompt_dict
1271
+ if model is not None:
1272
+ assert prompt_type is not None
1273
+ if prompt_type == PromptType.custom.name:
1274
+ assert prompt_dict is not None # should at least be {} or ''
1275
+ else:
1276
+ prompt_dict = ''
1277
+ llm, model_name, streamer, prompt_type_out = get_llm(use_openai_model=use_openai_model, model_name=model_name,
1278
+ model=model, tokenizer=tokenizer,
1279
+ stream_output=stream_output,
1280
+ max_new_tokens=max_new_tokens,
1281
+ temperature=temperature,
1282
+ repetition_penalty=repetition_penalty,
1283
+ top_k=top_k,
1284
+ top_p=top_p,
1285
+ prompt_type=prompt_type,
1286
+ prompt_dict=prompt_dict,
1287
+ prompter=prompter,
1288
+ verbose=verbose,
1289
+ )
1290
+
1291
+ if model_name in non_hf_types:
1292
+ # FIXME: for now, streams to stdout/stderr currently
1293
+ stream_output = False
1294
+
1295
+ use_context = False
1296
+ scores = []
1297
+ chain = None
1298
+
1299
+ if isinstance(document_choice, str):
1300
+ # support string as well
1301
+ document_choice = [document_choice]
1302
+ # get first DocumentChoices as command to use, ignore others
1303
+ doc_choices_set = set([x.name for x in list(DocumentChoices)])
1304
+ cmd = [x for x in document_choice if x in doc_choices_set]
1305
+ cmd = None if len(cmd) == 0 else cmd[0]
1306
+ # now have cmd, filter out for only docs
1307
+ document_choice = [x for x in document_choice if x not in doc_choices_set]
1308
+
1309
+ func_names = list(inspect.signature(get_similarity_chain).parameters)
1310
+ sim_kwargs = {k: v for k, v in locals().items() if k in func_names}
1311
+ missing_kwargs = [x for x in func_names if x not in sim_kwargs]
1312
+ assert not missing_kwargs, "Missing: %s" % missing_kwargs
1313
+ docs, chain, scores, use_context = get_similarity_chain(**sim_kwargs)
1314
+ if cmd in [DocumentChoices.All_Relevant_Only_Sources.name, DocumentChoices.Only_All_Sources.name]:
1315
+ formatted_doc_chunks = '\n\n'.join([get_url(x) + '\n\n' + x.page_content for x in docs])
1316
+ yield formatted_doc_chunks, ''
1317
+ return
1318
+ if chain is None and model_name not in non_hf_types:
1319
+ # can only return if HF type
1320
+ return
1321
+
1322
+ if stream_output:
1323
+ answer = None
1324
+ assert streamer is not None
1325
+ import queue
1326
+ bucket = queue.Queue()
1327
+ thread = EThread(target=chain, streamer=streamer, bucket=bucket)
1328
+ thread.start()
1329
+ outputs = ""
1330
+ prompt = None # FIXME
1331
+ try:
1332
+ for new_text in streamer:
1333
+ # print("new_text: %s" % new_text, flush=True)
1334
+ if bucket.qsize() > 0 or thread.exc:
1335
+ thread.join()
1336
+ outputs += new_text
1337
+ if prompter: # and False: # FIXME: pipeline can already use prompter
1338
+ output1 = prompter.get_response(outputs, prompt=prompt,
1339
+ sanitize_bot_response=sanitize_bot_response)
1340
+ yield output1, ''
1341
+ else:
1342
+ yield outputs, ''
1343
+ except BaseException:
1344
+ # if any exception, raise that exception if was from thread, first
1345
+ if thread.exc:
1346
+ raise thread.exc
1347
+ raise
1348
+ finally:
1349
+ # in case no exception and didn't join with thread yet, then join
1350
+ if not thread.exc:
1351
+ answer = thread.join()
1352
+ # in case raise StopIteration or broke queue loop in streamer, but still have exception
1353
+ if thread.exc:
1354
+ raise thread.exc
1355
+ # FIXME: answer is not string outputs from streamer. How to get actual final output?
1356
+ # answer = outputs
1357
+ else:
1358
+ answer = chain()
1359
+
1360
+ if not use_context:
1361
+ ret = answer['output_text']
1362
+ extra = ''
1363
+ yield ret, extra
1364
+ elif answer is not None:
1365
+ ret, extra = get_sources_answer(query, answer, scores, show_rank, answer_with_sources, verbose=verbose)
1366
+ yield ret, extra
1367
+ return
1368
+
1369
+
1370
+ def get_similarity_chain(query=None,
1371
+ use_openai_model=False, use_openai_embedding=False,
1372
+ first_para=False, text_limit=None, top_k_docs=4, chunk=True, chunk_size=512,
1373
+ user_path=None,
1374
+ detect_user_path_changes_every_query=False,
1375
+ db_type='faiss',
1376
+ model_name=None,
1377
+ hf_embedding_model="sentence-transformers/all-MiniLM-L6-v2",
1378
+ prompt_type=None,
1379
+ prompt_dict=None,
1380
+ cut_distanct=1.1,
1381
+ load_db_if_exists=False,
1382
+ db=None,
1383
+ langchain_mode=None,
1384
+ document_choice=[DocumentChoices.All_Relevant.name],
1385
+ n_jobs=-1,
1386
+ # beyond run_db_query:
1387
+ llm=None,
1388
+ verbose=False,
1389
+ cmd=None,
1390
+ ):
1391
+ # determine whether use of context out of docs is planned
1392
+ if not use_openai_model and prompt_type not in ['plain'] or model_name in non_hf_types:
1393
+ if langchain_mode in ['Disabled', 'ChatLLM', 'LLM']:
1394
+ use_context = False
1395
+ else:
1396
+ use_context = True
1397
+ else:
1398
+ use_context = True
1399
+
1400
+ # https://github.com/hwchase17/langchain/issues/1946
1401
+ # FIXME: Seems no way to get size of chroma db to limit top_k_docs to avoid
1402
+ # Chroma collection MyData contains fewer than 4 elements.
1403
+ # type logger error
1404
+ k_db = 1000 if db_type == 'chroma' else top_k_docs # top_k_docs=100 works ok too
1405
+
1406
+ # FIXME: For All just go over all dbs instead of a separate db for All
1407
+ if not detect_user_path_changes_every_query and db is not None:
1408
+ # avoid looking at user_path during similarity search db handling,
1409
+ # if already have db and not updating from user_path every query
1410
+ # but if db is None, no db yet loaded (e.g. from prep), so allow user_path to be whatever it was
1411
+ user_path = None
1412
+ db, num_new_sources, new_sources_metadata = make_db(use_openai_embedding=use_openai_embedding,
1413
+ hf_embedding_model=hf_embedding_model,
1414
+ first_para=first_para, text_limit=text_limit,
1415
+ chunk=chunk,
1416
+ chunk_size=chunk_size,
1417
+ langchain_mode=langchain_mode,
1418
+ user_path=user_path,
1419
+ db_type=db_type,
1420
+ load_db_if_exists=load_db_if_exists,
1421
+ db=db,
1422
+ n_jobs=n_jobs,
1423
+ verbose=verbose)
1424
+
1425
+ if db and use_context:
1426
+ if not isinstance(db, Chroma):
1427
+ # only chroma supports filtering
1428
+ filter_kwargs = {}
1429
+ else:
1430
+ # if here then some cmd + documents selected or just documents selected
1431
+ if len(document_choice) >= 2:
1432
+ or_filter = [{"source": {"$eq": x}} for x in document_choice]
1433
+ filter_kwargs = dict(filter={"$or": or_filter})
1434
+ elif len(document_choice) == 1:
1435
+ # degenerate UX bug in chroma
1436
+ one_filter = [{"source": {"$eq": x}} for x in document_choice][0]
1437
+ filter_kwargs = dict(filter=one_filter)
1438
+ else:
1439
+ # shouldn't reach
1440
+ filter_kwargs = {}
1441
+ if cmd == DocumentChoices.Just_LLM.name:
1442
+ docs = []
1443
+ scores = []
1444
+ elif cmd == DocumentChoices.Only_All_Sources.name:
1445
+ if isinstance(db, Chroma):
1446
+ db_get = db._collection.get(where=filter_kwargs.get('filter'))
1447
+ else:
1448
+ db_get = db.get()
1449
+ # similar to langchain's chroma's _results_to_docs_and_scores
1450
+ docs_with_score = [(Document(page_content=result[0], metadata=result[1] or {}), 0)
1451
+ for result in zip(db_get['documents'], db_get['metadatas'])][:top_k_docs]
1452
+ docs = [x[0] for x in docs_with_score]
1453
+ scores = [x[1] for x in docs_with_score]
1454
+ else:
1455
+ docs_with_score = db.similarity_search_with_score(query, k=k_db, **filter_kwargs)[:top_k_docs]
1456
+ # cut off so no high distance docs/sources considered
1457
+ docs = [x[0] for x in docs_with_score if x[1] < cut_distanct]
1458
+ scores = [x[1] for x in docs_with_score if x[1] < cut_distanct]
1459
+ if len(scores) > 0 and verbose:
1460
+ print("Distance: min: %s max: %s mean: %s median: %s" %
1461
+ (scores[0], scores[-1], np.mean(scores), np.median(scores)), flush=True)
1462
+ else:
1463
+ docs = []
1464
+ scores = []
1465
+
1466
+ if not docs and use_context and model_name not in non_hf_types:
1467
+ # if HF type and have no docs, can bail out
1468
+ return docs, None, [], False
1469
+
1470
+ if cmd in [DocumentChoices.All_Relevant_Only_Sources.name, DocumentChoices.Only_All_Sources.name]:
1471
+ # no LLM use
1472
+ return docs, None, [], False
1473
+
1474
+ common_words_file = "data/NGSL_1.2_stats.csv.zip"
1475
+ if os.path.isfile(common_words_file):
1476
+ df = pd.read_csv(common_words_file)
1477
+ import string
1478
+ reduced_query = query.translate(str.maketrans(string.punctuation, ' ' * len(string.punctuation))).strip()
1479
+ reduced_query_words = reduced_query.split(' ')
1480
+ set_common = set(df['Lemma'].values.tolist())
1481
+ num_common = len([x for x in reduced_query_words if x.lower() in set_common])
1482
+ frac_common = num_common / len(reduced_query_words) if reduced_query_words else 0
1483
+ # FIXME: report to user bad query that uses too many common words
1484
+ if verbose:
1485
+ print("frac_common: %s" % frac_common, flush=True)
1486
+
1487
+ if len(docs) == 0:
1488
+ # avoid context == in prompt then
1489
+ use_context = False
1490
+
1491
+ if not use_openai_model and prompt_type not in ['plain'] or model_name in non_hf_types:
1492
+ # instruct-like, rather than few-shot prompt_type='plain' as default
1493
+ # but then sources confuse the model with how inserted among rest of text, so avoid
1494
+ prefix = ""
1495
+ if langchain_mode in ['Disabled', 'ChatLLM', 'LLM'] or not use_context:
1496
+ template = """%s{context}{question}""" % prefix
1497
+ else:
1498
+ template = """%s
1499
+ ==
1500
+ {context}
1501
+ ==
1502
+ {question}""" % prefix
1503
+ prompt = PromptTemplate(
1504
+ # input_variables=["summaries", "question"],
1505
+ input_variables=["context", "question"],
1506
+ template=template,
1507
+ )
1508
+ chain = load_qa_chain(llm, prompt=prompt)
1509
+ else:
1510
+ chain = load_qa_with_sources_chain(llm)
1511
+
1512
+ if not use_context:
1513
+ chain_kwargs = dict(input_documents=[], question=query)
1514
+ else:
1515
+ chain_kwargs = dict(input_documents=docs, question=query)
1516
+
1517
+ target = wrapped_partial(chain, chain_kwargs)
1518
+ return docs, target, scores, use_context
1519
+
1520
+
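
The filter construction above is self-contained enough to factor out; a minimal sketch of the same mapping (the helper name is illustrative, not part of the module):

def make_chroma_filter(document_choice):
    # Two or more selected documents become a $or of $eq clauses on the
    # "source" metadata field; a single selection is passed directly because
    # of the one-element $or quirk noted above.
    if len(document_choice) >= 2:
        return dict(filter={"$or": [{"source": {"$eq": x}} for x in document_choice]})
    if len(document_choice) == 1:
        return dict(filter={"source": {"$eq": document_choice[0]}})
    return {}

# make_chroma_filter(["a.pdf", "b.pdf"])
# -> {'filter': {'$or': [{'source': {'$eq': 'a.pdf'}}, {'source': {'$eq': 'b.pdf'}}]}}
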
1521
+ def get_sources_answer(query, answer, scores, show_rank, answer_with_sources, verbose=False):
1522
+ if verbose:
1523
+ print("query: %s" % query, flush=True)
1524
+ print("answer: %s" % answer['output_text'], flush=True)
1525
+
1526
+ if len(answer['input_documents']) == 0:
1527
+ extra = ''
1528
+ ret = answer['output_text'] + extra
1529
+ return ret, extra
1530
+
1531
+ # link
1532
+ answer_sources = [(max(0.0, 1.5 - score) / 1.5, get_url(doc)) for score, doc in
1533
+ zip(scores, answer['input_documents'])]
1534
+ answer_sources_dict = defaultdict(list)
1535
+ [answer_sources_dict[url].append(score) for score, url in answer_sources]
1536
+ answers_dict = {}
1537
+ for url, scores_url in answer_sources_dict.items():
1538
+ answers_dict[url] = np.max(scores_url)
1539
+ answer_sources = [(score, url) for url, score in answers_dict.items()]
1540
+ answer_sources.sort(key=lambda x: x[0], reverse=True)
1541
+ if show_rank:
1542
+ # answer_sources = ['%d | %s' % (1 + rank, url) for rank, (score, url) in enumerate(answer_sources)]
1543
+ # sorted_sources_urls = "Sources [Rank | Link]:<br>" + "<br>".join(answer_sources)
1544
+ answer_sources = ['%s' % url for rank, (score, url) in enumerate(answer_sources)]
1545
+ sorted_sources_urls = "Ranked Sources:<br>" + "<br>".join(answer_sources)
1546
+ else:
1547
+ answer_sources = ['<li>%.2g | %s</li>' % (score, url) for score, url in answer_sources]
1548
+ sorted_sources_urls = f"{source_prefix}<p><ul>" + "<p>".join(answer_sources)
1549
+ sorted_sources_urls += f"</ul></p>{source_postfix}"
1550
+
1551
+ if not answer['output_text'].endswith('\n'):
1552
+ answer['output_text'] += '\n'
1553
+
1554
+ if answer_with_sources:
1555
+ extra = '\n' + sorted_sources_urls
1556
+ else:
1557
+ extra = ''
1558
+ ret = answer['output_text'] + extra
1559
+ return ret, extra
1560
+
1561
+
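
The per-source score shown above is just the vector distance rescaled into a 0-1 range; a small worked example of that mapping (standalone, for illustration only):

def distance_to_relevance(distance, cutoff=1.5):
    # Same formula as above: distance 0 maps to 1.0, anything at or beyond
    # the cutoff maps to 0.0.
    return max(0.0, cutoff - distance) / cutoff

print(round(distance_to_relevance(0.3), 3))  # 0.8
print(round(distance_to_relevance(1.2), 3))  # 0.2
print(distance_to_relevance(2.0))            # 0.0
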
1562
+ def chunk_sources(sources, chunk=True, chunk_size=512):
1563
+ if not chunk:
1564
+ return sources
1565
+ source_chunks = []
1566
+ # Below for known separator
1567
+ # splitter = CharacterTextSplitter(separator=" ", chunk_size=chunk_size, chunk_overlap=0)
1568
+ splitter = RecursiveCharacterTextSplitter(chunk_size=chunk_size, chunk_overlap=0)
1569
+ for source in sources:
1570
+ # print(source.metadata['source'], flush=True)
1571
+ for chunky in splitter.split_text(source.page_content):
1572
+ source_chunks.append(Document(page_content=chunky, metadata=source.metadata))
1573
+ return source_chunks
1574
+
1575
+
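
A short usage sketch for chunk_sources (the sample text is made up; the langchain 0.0.x import path for Document is an assumption here):

from langchain.docstore.document import Document  # same Document class used above

doc = Document(page_content="word " * 500, metadata={"source": "example.txt"})
chunks = chunk_sources([doc], chunk=True, chunk_size=128)
# Every chunk keeps the original metadata, so source attribution survives splitting.
print(len(chunks), chunks[0].metadata)
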
1576
+ def get_db_from_hf(dest=".", db_dir='db_dir_DriverlessAI_docs.zip'):
1577
+ from huggingface_hub import hf_hub_download
1578
+ # True for case when locally already logged in with correct token, so don't have to set key
1579
+ token = os.getenv('HUGGINGFACE_API_TOKEN', True)
1580
+ path_to_zip_file = hf_hub_download('h2oai/db_dirs', db_dir, token=token, repo_type='dataset')
1581
+ import zipfile
1582
+ with zipfile.ZipFile(path_to_zip_file, 'r') as zip_ref:
1583
+ persist_directory = os.path.dirname(zip_ref.namelist()[0])
1584
+ remove(persist_directory)
1585
+ zip_ref.extractall(dest)
1586
+ return path_to_zip_file
1587
+
1588
+
1589
+ # Note dir has space in some cases, while zip does not
1590
+ some_db_zips = [['db_dir_DriverlessAI_docs.zip', 'db_dir_DriverlessAI docs', 'CC-BY-NC license'],
1591
+ ['db_dir_UserData.zip', 'db_dir_UserData', 'CC-BY license for ArXiv'],
1592
+ ['db_dir_github_h2oGPT.zip', 'db_dir_github h2oGPT', 'ApacheV2 license'],
1593
+ ['db_dir_wiki.zip', 'db_dir_wiki', 'CC-BY-SA Wikipedia license'],
1594
+ # ['db_dir_wiki_full.zip', 'db_dir_wiki_full.zip', '23GB, 05/04/2023 CC-BY-SA Wiki license'],
1595
+ ]
1596
+
1597
+ all_db_zips = some_db_zips + \
1598
+ [['db_dir_wiki_full.zip', 'db_dir_wiki_full.zip', '23GB, 05/04/2023 CC-BY-SA Wiki license'],
1599
+ ]
1600
+
1601
+
1602
+ def get_some_dbs_from_hf(dest='.', db_zips=None):
1603
+ if db_zips is None:
1604
+ db_zips = some_db_zips
1605
+ for db_dir, dir_expected, license1 in db_zips:
1606
+ path_to_zip_file = get_db_from_hf(dest=dest, db_dir=db_dir)
1607
+ assert os.path.isfile(path_to_zip_file), "Missing zip in %s" % path_to_zip_file
1608
+ if dir_expected:
1609
+ assert os.path.isdir(os.path.join(dest, dir_expected)), "Missing path for %s" % dir_expected
1610
+ assert os.path.isdir(os.path.join(dest, dir_expected, 'index')), "Missing index in %s" % dir_expected
1611
+
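
A minimal sketch of pulling the prebuilt vector databases listed above (network access is required; a HUGGINGFACE_API_TOKEN env var or a cached login is assumed where the dataset needs it):

# Grab every database listed in some_db_zips and unpack under ./db_dirs
get_some_dbs_from_hf(dest='./db_dirs')

# Or fetch a single index, e.g. only the h2oGPT GitHub docs
get_some_dbs_from_hf(dest='./db_dirs',
                     db_zips=[['db_dir_github_h2oGPT.zip',
                               'db_dir_github h2oGPT',
                               'ApacheV2 license']])
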
1612
+ def _create_local_weaviate_client():
1613
+ WEAVIATE_URL = os.getenv('WEAVIATE_URL', "http://localhost:8080")
1614
+ WEAVIATE_USERNAME = os.getenv('WEAVIATE_USERNAME')
1615
+ WEAVIATE_PASSWORD = os.getenv('WEAVIATE_PASSWORD')
1616
+ WEAVIATE_SCOPE = os.getenv('WEAVIATE_SCOPE', "offline_access")
1617
+
1618
+ resource_owner_config = None
1619
+ if WEAVIATE_USERNAME is not None and WEAVIATE_PASSWORD is not None:
1620
+ resource_owner_config = weaviate.AuthClientPassword(
1621
+ username=WEAVIATE_USERNAME,
1622
+ password=WEAVIATE_PASSWORD,
1623
+ scope=WEAVIATE_SCOPE
1624
+ )
1625
+
1626
+ try:
1627
+ client = weaviate.Client(WEAVIATE_URL, auth_client_secret=resource_owner_config)
1628
+ except Exception as e:
1629
+ print(f"Failed to create Weaviate client: {e}")
1630
+ return None
+ return client  # reached only if the client was created successfully
1631
+
1632
+ if __name__ == '__main__':
1633
+ pass
gradio_runner.py DELETED
@@ -1 +0,0 @@
1
- ../../gradio_runner.py
 
 
gradio_runner.py ADDED
The diff for this file is too large to render. See raw diff
 
gradio_themes.py DELETED
@@ -1 +0,0 @@
1
- ../../gradio_themes.py
 
 
gradio_themes.py ADDED
@@ -0,0 +1,183 @@
1
+ from __future__ import annotations
2
+
3
+ from typing import Iterable
4
+
5
+ from gradio.themes.soft import Soft
6
+ from gradio.themes import Color
7
+ from gradio.themes.utils import colors, sizes, fonts
8
+
9
+ h2o_yellow = Color(
10
+ name="yellow",
11
+ c50="#fffef2",
12
+ c100="#fff9e6",
13
+ c200="#ffecb3",
14
+ c300="#ffe28c",
15
+ c400="#ffd659",
16
+ c500="#fec925",
17
+ c600="#e6ac00",
18
+ c700="#bf8f00",
19
+ c800="#a67c00",
20
+ c900="#664d00",
21
+ c950="#403000",
22
+ )
23
+ h2o_gray = Color(
24
+ name="gray",
25
+ c50="#f8f8f8",
26
+ c100="#e5e5e5",
27
+ c200="#cccccc",
28
+ c300="#b2b2b2",
29
+ c400="#999999",
30
+ c500="#7f7f7f",
31
+ c600="#666666",
32
+ c700="#4c4c4c",
33
+ c800="#333333",
34
+ c900="#191919",
35
+ c950="#0d0d0d",
36
+ )
37
+
38
+
39
+ class H2oTheme(Soft):
40
+ def __init__(
41
+ self,
42
+ *,
43
+ primary_hue: colors.Color | str = h2o_yellow,
44
+ secondary_hue: colors.Color | str = h2o_yellow,
45
+ neutral_hue: colors.Color | str = h2o_gray,
46
+ spacing_size: sizes.Size | str = sizes.spacing_md,
47
+ radius_size: sizes.Size | str = sizes.radius_md,
48
+ text_size: sizes.Size | str = sizes.text_lg,
49
+ font: fonts.Font
50
+ | str
51
+ | Iterable[fonts.Font | str] = (
52
+ fonts.GoogleFont("Montserrat"),
53
+ "ui-sans-serif",
54
+ "system-ui",
55
+ "sans-serif",
56
+ ),
57
+ font_mono: fonts.Font
58
+ | str
59
+ | Iterable[fonts.Font | str] = (
60
+ fonts.GoogleFont("IBM Plex Mono"),
61
+ "ui-monospace",
62
+ "Consolas",
63
+ "monospace",
64
+ ),
65
+ ):
66
+ super().__init__(
67
+ primary_hue=primary_hue,
68
+ secondary_hue=secondary_hue,
69
+ neutral_hue=neutral_hue,
70
+ spacing_size=spacing_size,
71
+ radius_size=radius_size,
72
+ text_size=text_size,
73
+ font=font,
74
+ font_mono=font_mono,
75
+ )
76
+ super().set(
77
+ link_text_color="#3344DD",
78
+ link_text_color_hover="#3344DD",
79
+ link_text_color_visited="#3344DD",
80
+ link_text_color_dark="#74abff",
81
+ link_text_color_hover_dark="#a3c8ff",
82
+ link_text_color_active_dark="#a3c8ff",
83
+ link_text_color_visited_dark="#74abff",
84
+ button_primary_text_color="*neutral_950",
85
+ button_primary_text_color_dark="*neutral_950",
86
+ button_primary_background_fill="*primary_500",
87
+ button_primary_background_fill_dark="*primary_500",
88
+ block_label_background_fill="*primary_500",
89
+ block_label_background_fill_dark="*primary_500",
90
+ block_label_text_color="*neutral_950",
91
+ block_label_text_color_dark="*neutral_950",
92
+ block_title_text_color="*neutral_950",
93
+ block_title_text_color_dark="*neutral_950",
94
+ block_background_fill_dark="*neutral_950",
95
+ body_background_fill="*neutral_50",
96
+ body_background_fill_dark="*neutral_900",
97
+ background_fill_primary_dark="*block_background_fill",
98
+ block_radius="0 0 8px 8px",
99
+ checkbox_label_text_color_selected_dark='#000000',
100
+ )
101
+
102
+
103
+ class SoftTheme(Soft):
104
+ def __init__(
105
+ self,
106
+ *,
107
+ primary_hue: colors.Color | str = colors.indigo,
108
+ secondary_hue: colors.Color | str = colors.indigo,
109
+ neutral_hue: colors.Color | str = colors.gray,
110
+ spacing_size: sizes.Size | str = sizes.spacing_md,
111
+ radius_size: sizes.Size | str = sizes.radius_md,
112
+ text_size: sizes.Size | str = sizes.text_md,
113
+ font: fonts.Font
114
+ | str
115
+ | Iterable[fonts.Font | str] = (
116
+ fonts.GoogleFont("Montserrat"),
117
+ "ui-sans-serif",
118
+ "system-ui",
119
+ "sans-serif",
120
+ ),
121
+ font_mono: fonts.Font
122
+ | str
123
+ | Iterable[fonts.Font | str] = (
124
+ fonts.GoogleFont("IBM Plex Mono"),
125
+ "ui-monospace",
126
+ "Consolas",
127
+ "monospace",
128
+ ),
129
+ ):
130
+ super().__init__(
131
+ primary_hue=primary_hue,
132
+ secondary_hue=secondary_hue,
133
+ neutral_hue=neutral_hue,
134
+ spacing_size=spacing_size,
135
+ radius_size=radius_size,
136
+ text_size=text_size,
137
+ font=font,
138
+ font_mono=font_mono,
139
+ )
140
+
141
+
142
+ h2o_logo = '<svg id="Layer_1" data-name="Layer 1" xmlns="http://www.w3.org/2000/svg" width="100%" height="100%"' \
143
+ ' viewBox="0 0 600.28 600.28"><defs><style>.cls-1{fill:#fec925;}.cls-2{fill:#161616;}.cls-3{fill:' \
144
+ '#54585a;}</style></defs><g id="Fill-1"><rect class="cls-1" width="600.28" height="600.28" ' \
145
+ 'rx="23.24"/></g><path class="cls-2" d="M174.33,246.06v92.78H152.86v-38H110.71v38H89.24V246.06h21.' \
146
+ '47v36.58h42.15V246.06Z"/><path class="cls-2" d="M259.81,321.34v17.5H189.7V324.92l35.78-33.8c8.22-7.' \
147
+ '82,9.68-12.59,9.68-17.09,0-7.29-5-11.53-14.85-11.53-7.95,0-14.71,3-19.21,9.27L185.46,261.7c7.15-10' \
148
+ '.47,20.14-17.23,36.84-17.23,20.68,0,34.46,10.6,34.46,27.44,0,9-2.52,17.22-15.51,29.29l-21.33,20.14Z"' \
149
+ '/><path class="cls-2" d="M268.69,292.45c0-27.57,21.47-48,50.76-48s50.76,20.28,50.76,48-21.6,48-50.' \
150
+ '76,48S268.69,320,268.69,292.45Zm79.78,0c0-17.63-12.46-29.69-29-29.69s-29,12.06-29,29.69,12.46,29.69' \
151
+ ',29,29.69S348.47,310.08,348.47,292.45Z"/><path class="cls-3" d="M377.23,326.91c0-7.69,5.7-12.73,12.' \
152
+ '85-12.73s12.86,5,12.86,12.73a12.86,12.86,0,1,1-25.71,0Z"/><path class="cls-3" d="M481.4,298.15v40.' \
153
+ '69H462.05V330c-3.84,6.49-11.27,9.94-21.74,9.94-16.7,0-26.64-9.28-26.64-21.61,0-12.59,8.88-21.34,30.' \
154
+ '62-21.34h16.43c0-8.87-5.3-14-16.43-14-7.55,0-15.37,2.51-20.54,6.62l-7.43-14.44c7.82-5.57,19.35-8.' \
155
+ '62,30.75-8.62C468.81,266.47,481.4,276.54,481.4,298.15Zm-20.68,18.16V309H446.54c-9.67,0-12.72,3.57-' \
156
+ '12.72,8.35,0,5.16,4.37,8.61,11.66,8.61C452.37,326,458.34,322.8,460.72,316.31Z"/><path class="cls-3"' \
157
+ ' d="M497.56,246.06c0-6.49,5.17-11.53,12.86-11.53s12.86,4.77,12.86,11.13c0,6.89-5.17,11.93-12.86,' \
158
+ '11.93S497.56,252.55,497.56,246.06Zm2.52,21.47h20.68v71.31H500.08Z"/></svg>'
159
+
160
+
161
+ def get_h2o_title(title):
162
+ return f"""<div style="display:flex; justify-content:center; margin-bottom:30px;">
163
+ <div style="height: 60px; width: 60px; margin-right:20px;">{h2o_logo}</div>
164
+ <h1 style="line-height:60px">{title}</h1>
165
+ </div>
166
+ <div style="float:right; height: 80px; width: 80px; margin-top:-100px">
167
+ <img src=https://raw.githubusercontent.com/h2oai/h2ogpt/main/docs/h2o-qr.png></img>
168
+ </div>
169
+ """
170
+
171
+
172
+ def get_simple_title(title):
173
+ return f"""<h1 align="center"> {title}</h1>"""
174
+
175
+
176
+ def get_dark_js():
177
+ return """() => {
178
+ if (document.querySelectorAll('.dark').length) {
179
+ document.querySelectorAll('.dark').forEach(el => el.classList.remove('dark'));
180
+ } else {
181
+ document.querySelector('body').classList.add('dark');
182
+ }
183
+ }"""
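
A minimal sketch of how these pieces fit into a Blocks app (gradio 3.31-era API assumed; the button and title are illustrative):

import gradio as gr

with gr.Blocks(theme=H2oTheme(), title="h2oGPT") as demo:
    gr.HTML(get_h2o_title("h2oGPT"))
    dark_mode_btn = gr.Button("Dark Mode", variant="secondary")
    # get_dark_js() returns a JS snippet that toggles the page's .dark class
    dark_mode_btn.click(fn=None, inputs=None, outputs=None, _js=get_dark_js())

# demo.launch()
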
gradio_ui DELETED
@@ -1 +0,0 @@
1
- ../../gradio_ui
 
 
gradio_ui/css.py ADDED
@@ -0,0 +1,51 @@
1
+ def get_css(kwargs) -> str:
2
+ if kwargs['h2ocolors']:
3
+ css_code = """footer {visibility: hidden;}
4
+ body{background:linear-gradient(#f5f5f5,#e5e5e5);}
5
+ body.dark{background:linear-gradient(#000000,#0d0d0d);}
6
+ """
7
+ else:
8
+ css_code = """footer {visibility: hidden}"""
9
+
10
+ css_code += make_css_base()
11
+ return css_code
12
+
13
+ def make_css_base() -> str:
14
+ return """
15
+ @import url('https://fonts.googleapis.com/css2?family=Source+Sans+Pro:wght@400;600&display=swap');
16
+
17
+ body.dark{#warning {background-color: #555555};}
18
+
19
+ #small_btn {
20
+ margin: 0.6em 0em 0.55em 0;
21
+ max-width: 20em;
22
+ min-width: 5em !important;
23
+ height: 5em;
24
+ font-size: 14px !important;
25
+ }
26
+
27
+ #prompt-form {
28
+ border: 1px solid var(--primary-500) !important;
29
+ }
30
+
31
+ #prompt-form.block {
32
+ border-radius: var(--block-radius) !important;
33
+ }
34
+
35
+ #prompt-form textarea {
36
+ border: 1px solid rgb(209, 213, 219);
37
+ }
38
+
39
+ #prompt-form label > div {
40
+ margin-top: 4px;
41
+ }
42
+
43
+ button.primary:hover {
44
+ background-color: var(--primary-600) !important;
45
+ transition: .2s;
46
+ }
47
+
48
+ #prompt-form-area {
49
+ margin-bottom: 2.5rem;
50
+ }
51
+ """
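
Usage is a one-liner; a sketch (the kwargs dict here is illustrative and would normally be the app's main kwargs):

css_code = get_css({'h2ocolors': True})
# css_code would then typically be passed to gr.Blocks(css=css_code, theme=..., ...)
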
gradio_ui/prompt_form.py ADDED
@@ -0,0 +1,24 @@
1
+ import gradio as gr
2
+
3
+
4
+ def make_prompt_form(kwargs):
5
+ if kwargs['input_lines'] > 1:
6
+ instruction_label = "press Shift-Enter or click Submit to send message, press Enter for multiple input lines"
7
+ else:
8
+ instruction_label = "press Enter or click Submit to send message, press Shift-Enter for more lines"
9
+
10
+ with gr.Row(elem_id='prompt-form-area'):
11
+ with gr.Column(scale=50):
12
+ instruction = gr.Textbox(
13
+ lines=kwargs['input_lines'],
14
+ label='Ask anything',
15
+ placeholder=kwargs['placeholder_instruction'],
16
+ info=instruction_label,
17
+ elem_id='prompt-form'
18
+ )
19
+ instruction.style(container=True)
20
+ with gr.Row():
21
+ submit = gr.Button(value='Submit', variant='primary').style(full_width=False, size='sm')
22
+ stop_btn = gr.Button(value="Stop", variant='secondary').style(full_width=False, size='sm')
23
+
24
+ return instruction, submit, stop_btn
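
A minimal wiring sketch (the kwargs values and the echo handler are placeholders for the real generation callback):

import gradio as gr

kwargs = dict(input_lines=1, placeholder_instruction="Ask me anything")
with gr.Blocks() as demo:
    instruction, submit, stop_btn = make_prompt_form(kwargs)
    response = gr.Textbox(label="Response")
    # Echo the prompt back; a real app would route this through the model instead.
    submit.click(fn=lambda text: text, inputs=instruction, outputs=response)

# demo.launch()
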
h2o-logo.svg DELETED
h2o-logo.svg ADDED
h2oai_pipeline.py DELETED
@@ -1 +0,0 @@
1
- ../../h2oai_pipeline.py
 
 
h2oai_pipeline.py ADDED
@@ -0,0 +1,177 @@
1
+ import os
2
+
3
+ from transformers import TextGenerationPipeline
4
+ from transformers.pipelines.text_generation import ReturnType
5
+
6
+ from stopping import get_stopping
7
+ from prompter import Prompter, PromptType
8
+
9
+
10
+ class H2OTextGenerationPipeline(TextGenerationPipeline):
11
+ def __init__(self, *args, debug=False, chat=False, stream_output=False,
12
+ sanitize_bot_response=True,
13
+ use_prompter=True, prompter=None,
14
+ prompt_type=None, prompt_dict=None,
15
+ max_input_tokens=2048 - 256, **kwargs):
16
+ """
17
+ HF-like pipeline, but handle instruction prompting and stopping (for some models)
18
+ :param args:
19
+ :param debug:
20
+ :param chat:
21
+ :param stream_output:
22
+ :param sanitize_bot_response:
23
+ :param use_prompter: Whether to use prompter. If pass prompt_type, will make prompter
24
+ :param prompter: prompter, can pass if have already
25
+ :param prompt_type: prompt_type, e.g. human_bot. See prompt_type to model mapping in from prompter.py.
26
+ If use_prompter, then will make prompter and use it.
27
+ :param prompt_dict: dict of get_prompt(, return_dict=True) for prompt_type=custom
28
+ :param max_input_tokens:
29
+ :param kwargs:
30
+ """
31
+ super().__init__(*args, **kwargs)
32
+ self.prompt_text = None
33
+ self.use_prompter = use_prompter
34
+ self.prompt_type = prompt_type
35
+ self.prompt_dict = prompt_dict
36
+ self.prompter = prompter
37
+ if self.use_prompter:
38
+ if self.prompter is not None:
39
+ assert self.prompter.prompt_type is not None
40
+ else:
41
+ self.prompter = Prompter(self.prompt_type, self.prompt_dict, debug=debug, chat=chat,
42
+ stream_output=stream_output)
43
+ self.human = self.prompter.humanstr
44
+ self.bot = self.prompter.botstr
45
+ self.can_stop = True
46
+ else:
47
+ self.prompter = None
48
+ self.human = None
49
+ self.bot = None
50
+ self.can_stop = False
51
+ self.sanitize_bot_response = sanitize_bot_response
52
+ self.max_input_tokens = max_input_tokens # not for generate, so ok that not kwargs
53
+
54
+ def preprocess(self, prompt_text, prefix="", handle_long_generation=None, **generate_kwargs):
55
+ if hasattr(self.tokenizer, 'model_max_length'):
56
+ # model_max_length only defined for generate.py, not raw use of h2oai_pipeline.py
57
+ model_max_length = self.tokenizer.model_max_length
58
+ else:
59
+ # unknown
60
+ model_max_length = None
61
+
62
+ verbose = bool(int(os.getenv('VERBOSE_PIPELINE', '0')))
63
+ if model_max_length is not None:
64
+ num_prompt_tokens = None
65
+ # can't wait for "hole" if not plain prompt_type, since would lose prefix like <human>:
66
+ # For https://github.com/h2oai/h2ogpt/issues/192
67
+ for trial in range(0, 3):
68
+ prompt_tokens = self.tokenizer(prompt_text)['input_ids']
69
+ num_prompt_tokens = len(prompt_tokens)
70
+ if num_prompt_tokens > model_max_length:
71
+ # conservative by using int()
72
+ chars_per_token = int(len(prompt_text) / num_prompt_tokens)
73
+ prompt_text = prompt_text[-model_max_length * chars_per_token:]
74
+ if verbose:
75
+ print("reducing %s tokens, assuming average of %s chars/token for %s characters" % (
76
+ num_prompt_tokens, chars_per_token, len(prompt_text)), flush=True)
77
+ else:
78
+ if verbose:
79
+ print("using %s tokens with %s chars" % (num_prompt_tokens, len(prompt_text)), flush=True)
80
+ break
81
+
82
+ # if input prompt is some number of tokens, despite user request, can't have max_new_tokens more
83
+ #
84
+ if self.prompt_type not in [PromptType.plain.name, PromptType.plain.value]:
85
+ # then give room for prompt
86
+ fudge = 20
87
+ else:
88
+ fudge = 0
89
+ assert num_prompt_tokens is not None
90
+ max_new_tokens = max(0, min(generate_kwargs['max_new_tokens'],
91
+ model_max_length - (num_prompt_tokens + fudge)))
92
+ if max_new_tokens < generate_kwargs['max_new_tokens']:
93
+ if verbose:
94
+ print("Reduced max_new_tokens from %s -> %s" % (generate_kwargs['max_new_tokens'], max_new_tokens))
95
+ generate_kwargs['max_new_tokens'] = max_new_tokens
96
+
97
+ data_point = dict(context='', instruction=prompt_text, input='')
98
+ if self.prompter is not None:
99
+ prompt_text = self.prompter.generate_prompt(data_point)
100
+ self.prompt_text = prompt_text
101
+ if handle_long_generation is None:
102
+ # forces truncation of inputs to avoid critical failure
103
+ handle_long_generation = 'hole'
104
+ return super().preprocess(prompt_text, prefix=prefix, handle_long_generation=handle_long_generation,
105
+ **generate_kwargs)
106
+
107
+ def postprocess(self, model_outputs, return_type=ReturnType.FULL_TEXT, clean_up_tokenization_spaces=True):
108
+ records = super().postprocess(model_outputs, return_type=return_type,
109
+ clean_up_tokenization_spaces=clean_up_tokenization_spaces)
110
+ for rec in records:
111
+ if self.use_prompter:
112
+ outputs = rec['generated_text']
113
+ outputs = self.prompter.get_response(outputs, prompt=self.prompt_text,
114
+ sanitize_bot_response=self.sanitize_bot_response)
115
+ elif self.bot and self.human:
116
+ outputs = rec['generated_text'].split(self.bot)[1].strip().split(self.human)[0].strip()
117
+ else:
118
+ outputs = rec['generated_text']
119
+ rec['generated_text'] = outputs
120
+ return records
121
+
122
+ def _forward(self, model_inputs, **generate_kwargs):
123
+ if self.can_stop:
124
+ stopping_criteria = get_stopping(self.prompt_type, self.prompt_dict,
125
+ self.tokenizer, self.device,
126
+ human=self.human, bot=self.bot)
127
+ generate_kwargs['stopping_criteria'] = stopping_criteria
128
+ # return super()._forward(model_inputs, **generate_kwargs)
129
+ return self.__forward(model_inputs, **generate_kwargs)
130
+
131
+ # FIXME: Copy-paste of original _forward, but removed copy.deepcopy()
132
+ # FIXME: https://github.com/h2oai/h2ogpt/issues/172
133
+ def __forward(self, model_inputs, **generate_kwargs):
134
+ input_ids = model_inputs["input_ids"]
135
+ attention_mask = model_inputs.get("attention_mask", None)
136
+ # Allow empty prompts
137
+ if input_ids.shape[1] == 0:
138
+ input_ids = None
139
+ attention_mask = None
140
+ in_b = 1
141
+ else:
142
+ in_b = input_ids.shape[0]
143
+ prompt_text = model_inputs.pop("prompt_text")
144
+
145
+ ## If there is a prefix, we may need to adjust the generation length. Do so without permanently modifying
146
+ ## generate_kwargs, as some of the parameterization may come from the initialization of the pipeline.
147
+ # generate_kwargs = copy.deepcopy(generate_kwargs)
148
+ prefix_length = generate_kwargs.pop("prefix_length", 0)
149
+ if prefix_length > 0:
150
+ has_max_new_tokens = "max_new_tokens" in generate_kwargs or (
151
+ "generation_config" in generate_kwargs
152
+ and generate_kwargs["generation_config"].max_new_tokens is not None
153
+ )
154
+ if not has_max_new_tokens:
155
+ generate_kwargs["max_length"] = generate_kwargs.get("max_length") or self.model.config.max_length
156
+ generate_kwargs["max_length"] += prefix_length
157
+ has_min_new_tokens = "min_new_tokens" in generate_kwargs or (
158
+ "generation_config" in generate_kwargs
159
+ and generate_kwargs["generation_config"].min_new_tokens is not None
160
+ )
161
+ if not has_min_new_tokens and "min_length" in generate_kwargs:
162
+ generate_kwargs["min_length"] += prefix_length
163
+
164
+ # BS x SL
165
+ generated_sequence = self.model.generate(input_ids=input_ids, attention_mask=attention_mask, **generate_kwargs)
166
+ out_b = generated_sequence.shape[0]
167
+ if self.framework == "pt":
168
+ generated_sequence = generated_sequence.reshape(in_b, out_b // in_b, *generated_sequence.shape[1:])
169
+ elif self.framework == "tf":
170
+ from transformers import is_tf_available
171
+ if is_tf_available():
172
+ import tensorflow as tf
173
+ generated_sequence = tf.reshape(generated_sequence,
174
+ (in_b, out_b // in_b, *generated_sequence.shape[1:]))
175
+ else:
176
+ raise ValueError("TF not available.")
177
+ return {"generated_sequence": generated_sequence, "input_ids": input_ids, "prompt_text": prompt_text}
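
A minimal sketch of constructing the pipeline directly (the model name is only an example; loading it needs substantial RAM or a GPU, and max_new_tokens must be passed because preprocess() reads it):

from transformers import AutoModelForCausalLM, AutoTokenizer

model_name = "h2oai/h2ogpt-oig-oasst1-512-6_9b"  # example human_bot-style model
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)

pipe = H2OTextGenerationPipeline(model=model, tokenizer=tokenizer,
                                 prompt_type='human_bot')
res = pipe("Why is the sky blue?", max_new_tokens=128)
print(res[0]['generated_text'])
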
loaders.py DELETED
@@ -1 +0,0 @@
1
- ../../loaders.py
 
 
loaders.py ADDED
@@ -0,0 +1,50 @@
1
+ def get_loaders(llama_type, model_name, reward_type):
2
+ # NOTE: Some models need specific new prompt_type
3
+ # E.g. t5_xxl_true_nli_mixture has input format: "premise: PREMISE_TEXT hypothesis: HYPOTHESIS_TEXT".)
4
+ if llama_type:
5
+ from transformers import LlamaForCausalLM, LlamaTokenizer
6
+ model_loader = LlamaForCausalLM
7
+ tokenizer_loader = LlamaTokenizer
8
+ elif 'distilgpt2' in model_name.lower():
9
+ from transformers import AutoModelForCausalLM, AutoTokenizer
10
+ return AutoModelForCausalLM, AutoTokenizer
11
+ elif 'gpt2' in model_name.lower():
12
+ from transformers import GPT2LMHeadModel, GPT2Tokenizer
13
+ return GPT2LMHeadModel, GPT2Tokenizer
14
+ elif 'mbart-' in model_name.lower():
15
+ from transformers import MBartForConditionalGeneration, MBart50TokenizerFast
16
+ return MBartForConditionalGeneration, MBart50TokenizerFast
17
+ elif 't5' == model_name.lower() or \
18
+ 't5-' in model_name.lower() or \
19
+ 'flan-' in model_name.lower():
20
+ from transformers import AutoTokenizer, T5ForConditionalGeneration
21
+ return T5ForConditionalGeneration, AutoTokenizer
22
+ elif 'bigbird' in model_name:
23
+ from transformers import BigBirdPegasusForConditionalGeneration, AutoTokenizer
24
+ return BigBirdPegasusForConditionalGeneration, AutoTokenizer
25
+ elif 'bart-large-cnn-samsum' in model_name or 'flan-t5-base-samsum' in model_name:
26
+ from transformers import pipeline
27
+ return pipeline, "summarization"
28
+ elif reward_type or 'OpenAssistant/reward-model'.lower() in model_name.lower():
29
+ from transformers import AutoModelForSequenceClassification, AutoTokenizer
30
+ return AutoModelForSequenceClassification, AutoTokenizer
31
+ else:
32
+ from transformers import AutoTokenizer, AutoModelForCausalLM
33
+ model_loader = AutoModelForCausalLM
34
+ tokenizer_loader = AutoTokenizer
35
+ return model_loader, tokenizer_loader
36
+
37
+
38
+ def get_tokenizer(tokenizer_loader, tokenizer_base_model, local_files_only, resume_download, use_auth_token):
39
+ tokenizer = tokenizer_loader.from_pretrained(tokenizer_base_model,
40
+ local_files_only=local_files_only,
41
+ resume_download=resume_download,
42
+ use_auth_token=use_auth_token)
43
+
44
+ tokenizer.pad_token_id = 0 # different from the eos token
45
+ # when generating, we will use the logits of right-most token to predict the next token
46
+ # so the padding should be on the left,
47
+ # e.g. see: https://huggingface.co/transformers/v4.11.3/model_doc/t5.html#inference
48
+ tokenizer.padding_side = "left" # Allow batched inference
49
+
50
+ return tokenizer
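
A short sketch of the intended call pattern (the model name is an example; the download options are assumptions, not fixed by this module):

base_model = 'h2oai/h2ogpt-oig-oasst1-512-6_9b'  # example
model_loader, tokenizer_loader = get_loaders(llama_type=False,
                                             model_name=base_model,
                                             reward_type=False)
tokenizer = get_tokenizer(tokenizer_loader, base_model,
                          local_files_only=False, resume_download=True,
                          use_auth_token=False)
model = model_loader.from_pretrained(base_model)
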
prompter.py DELETED
@@ -1 +0,0 @@
1
- ../../prompter.py
 
 
prompter.py ADDED
@@ -0,0 +1,606 @@
1
+ import ast
2
+ import time
3
+ from enums import PromptType # also supports imports from this file from other files
4
+
5
+ non_hf_types = ['gpt4all_llama', 'llama', 'gptj']
6
+
7
+
8
+ prompt_type_to_model_name = {
9
+ 'plain': [
10
+ 'EleutherAI/gpt-j-6B',
11
+ 'EleutherAI/pythia-6.9b',
12
+ 'EleutherAI/pythia-12b',
13
+ 'EleutherAI/pythia-12b-deduped',
14
+ 'EleutherAI/gpt-neox-20b',
15
+ 'openlm-research/open_llama_7b_700bt_preview',
16
+ 'decapoda-research/llama-7b-hf',
17
+ 'decapoda-research/llama-13b-hf',
18
+ 'decapoda-research/llama-30b-hf',
19
+ 'decapoda-research/llama-65b-hf',
20
+ 'facebook/mbart-large-50-many-to-many-mmt',
21
+ 'philschmid/bart-large-cnn-samsum',
22
+ 'philschmid/flan-t5-base-samsum',
23
+ 'gpt2',
24
+ 'distilgpt2',
25
+ 'mosaicml/mpt-7b-storywriter',
26
+ 'mosaicml/mpt-7b-instruct', # internal code handles instruct
27
+ 'mosaicml/mpt-7b-chat', # NC, internal code handles instruct
28
+ 'gptj', # internally handles prompting
29
+ 'llama', # plain, or need to choose prompt_type for given TheBloke model
30
+ 'gpt4all_llama', # internally handles prompting
31
+ ],
32
+ 'prompt_answer': [
33
+ 'h2oai/h2ogpt-gm-oasst1-en-1024-20b',
34
+ 'h2oai/h2ogpt-gm-oasst1-en-1024-12b',
35
+ 'h2oai/h2ogpt-gm-oasst1-multilang-1024-20b',
36
+ 'h2oai/h2ogpt-gm-oasst1-en-2048-open-llama-7b-preview-300bt',
37
+ 'h2oai/h2ogpt-gm-oasst1-en-2048-open-llama-7b-preview-300bt-v2',
38
+ 'h2oai/h2ogpt-gm-oasst1-en-2048-open-llama-7b-preview-700bt',
39
+ 'h2oai/h2ogpt-gm-oasst1-multilang-2048-falcon-7b',
40
+ 'h2oai/h2ogpt-gm-oasst1-multilang-2048-falcon-7b-v2',
41
+ ],
42
+ 'instruct': [],
43
+ 'instruct_with_end': ['databricks/dolly-v2-12b'],
44
+ 'quality': [],
45
+ 'human_bot': [
46
+ 'h2oai/h2ogpt-oasst1-512-12b',
47
+ 'h2oai/h2ogpt-oasst1-512-20b',
48
+ 'h2oai/h2ogpt-oig-oasst1-256-6_9b',
49
+ 'h2oai/h2ogpt-oig-oasst1-512-6_9b',
50
+ 'h2oai/h2ogpt-oig-oasst1-256-6.9b', # legacy
51
+ 'h2oai/h2ogpt-oig-oasst1-512-6.9b', # legacy
52
+ 'h2oai/h2ogpt-research-oasst1-512-30b',
53
+ 'h2oai/h2ogpt-oasst1-falcon-40b',
54
+ 'h2oai/h2ogpt-oig-oasst1-falcon-40b',
55
+ ],
56
+ 'dai_faq': [],
57
+ 'summarize': [],
58
+ 'simple_instruct': ['t5-small', 't5-large', 'google/flan-t5', 'google/flan-t5-xxl', 'google/flan-ul2'],
59
+ 'instruct_vicuna': ['AlekseyKorshuk/vicuna-7b', 'TheBloke/stable-vicuna-13B-HF', 'junelee/wizard-vicuna-13b'],
60
+ 'human_bot_orig': ['togethercomputer/GPT-NeoXT-Chat-Base-20B'],
61
+ "open_assistant": ['OpenAssistant/oasst-sft-7-llama-30b-xor', 'oasst-sft-7-llama-30b'],
62
+ "wizard_lm": ['ehartford/WizardLM-7B-Uncensored', 'ehartford/WizardLM-13B-Uncensored'],
63
+ "wizard_mega": ['openaccess-ai-collective/wizard-mega-13b'],
64
+ "instruct_simple": ['JosephusCheung/Guanaco'],
65
+ }
66
+
67
+ inv_prompt_type_to_model_name = {v.strip(): k for k, l in prompt_type_to_model_name.items() for v in l}
68
+ inv_prompt_type_to_model_lower = {v.strip().lower(): k for k, l in prompt_type_to_model_name.items() for v in l}
69
+
70
+ prompt_types_strings = []
71
+ for p in PromptType:
72
+ prompt_types_strings.extend([p.name])
73
+
74
+ prompt_types = []
75
+ for p in PromptType:
76
+ prompt_types.extend([p.name, p.value, str(p.value)])
77
+
78
+
79
+ def get_prompt(prompt_type, prompt_dict, chat, context, reduced, return_dict=False):
80
+ prompt_dict_error = ''
81
+ if prompt_type == PromptType.custom.name and not isinstance(prompt_dict, dict):
82
+ try:
83
+ prompt_dict = ast.literal_eval(prompt_dict)
84
+ except BaseException as e:
85
+ prompt_dict_error = str(e)
86
+ if prompt_dict_error:
87
+ return dict(), prompt_dict_error
88
+
89
+ if prompt_type in [PromptType.custom.value, str(PromptType.custom.value),
90
+ PromptType.custom.name]:
91
+ promptA = prompt_dict.get('promptA', '')
92
+ promptB = prompt_dict.get('promptB', '')
93
+ PreInstruct = prompt_dict.get('PreInstruct', '')
94
+ PreInput = prompt_dict.get('PreInput', '')
95
+ PreResponse = prompt_dict.get('PreResponse', '')
96
+ terminate_response = prompt_dict.get('terminate_response', None)
97
+ chat_sep = prompt_dict.get('chat_sep', '\n')
98
+ humanstr = prompt_dict.get('humanstr', '')
99
+ botstr = prompt_dict.get('botstr', '')
100
+ elif prompt_type in [PromptType.plain.value, str(PromptType.plain.value),
101
+ PromptType.plain.name]:
102
+ promptA = promptB = PreInstruct = PreInput = PreResponse = ''
103
+ terminate_response = []
104
+ chat_sep = ''
105
+ humanstr = ''
106
+ botstr = ''
107
+ elif prompt_type == 'simple_instruct':
108
+ promptA = promptB = PreInstruct = PreInput = PreResponse = None
109
+ terminate_response = []
110
+ chat_sep = '\n'
111
+ humanstr = ''
112
+ botstr = ''
113
+ elif prompt_type in [PromptType.instruct.value, str(PromptType.instruct.value),
114
+ PromptType.instruct.name] + [PromptType.instruct_with_end.value,
115
+ str(PromptType.instruct_with_end.value),
116
+ PromptType.instruct_with_end.name]:
117
+ promptA = 'Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.\n' if not (
118
+ chat and reduced) else ''
119
+ promptB = 'Below is an instruction that describes a task. Write a response that appropriately completes the request.\n' if not (
120
+ chat and reduced) else ''
121
+
122
+ PreInstruct = """
123
+ ### Instruction:
124
+ """
125
+
126
+ PreInput = """
127
+ ### Input:
128
+ """
129
+
130
+ PreResponse = """
131
+ ### Response:
132
+ """
133
+ if prompt_type in [PromptType.instruct_with_end.value, str(PromptType.instruct_with_end.value),
134
+ PromptType.instruct_with_end.name]:
135
+ terminate_response = ['### End']
136
+ else:
137
+ terminate_response = None
138
+ chat_sep = '\n'
139
+ humanstr = PreInstruct
140
+ botstr = PreResponse
141
+ elif prompt_type in [PromptType.quality.value, str(PromptType.quality.value),
142
+ PromptType.quality.name]:
143
+ promptA = 'Write a detailed high-quality, accurate, fair, Response with about 100 words by following the Instruction as applied on the Input.\n' if not (
144
+ chat and reduced) else ''
145
+ promptB = 'Write a detailed high-quality, accurate, fair, Response with about 100 words by following the Instruction.\n' if not (
146
+ chat and reduced) else ''
147
+
148
+ PreInstruct = """
149
+ ### Instruction:
150
+ """
151
+
152
+ PreInput = """
153
+ ### Input:
154
+ """
155
+
156
+ PreResponse = """
157
+ ### Response:
158
+ """
159
+ terminate_response = None
160
+ chat_sep = '\n'
161
+ humanstr = PreInstruct # first thing human says
162
+ botstr = PreResponse # first thing bot says
163
+ elif prompt_type in [PromptType.human_bot.value, str(PromptType.human_bot.value),
164
+ PromptType.human_bot.name] + [PromptType.human_bot_orig.value,
165
+ str(PromptType.human_bot_orig.value),
166
+ PromptType.human_bot_orig.name]:
167
+ human = '<human>:'
168
+ bot = "<bot>:"
169
+ if reduced or context or prompt_type in [PromptType.human_bot.value, str(PromptType.human_bot.value),
170
+ PromptType.human_bot.name]:
171
+ preprompt = ''
172
+ else:
173
+ cur_date = time.strftime('%Y-%m-%d')
174
+ cur_time = time.strftime('%H:%M:%S %p %Z')
175
+
176
+ PRE_PROMPT = """\
177
+ Current Date: {}
178
+ Current Time: {}
179
+
180
+ """
181
+ preprompt = PRE_PROMPT.format(cur_date, cur_time)
182
+ start = human
183
+ promptB = promptA = '%s%s ' % (preprompt, start)
184
+
185
+ PreInstruct = ""
186
+
187
+ PreInput = None
188
+
189
+ if reduced:
190
+ # when making context, want it to appear as-if LLM generated, which starts with space after :
191
+ PreResponse = bot + ' '
192
+ else:
193
+ # normally LLM adds space after this, because was how trained.
194
+ # if add space here, non-unique tokenization will often make LLM produce wrong output
195
+ PreResponse = bot
196
+
197
+ terminate_response = [start, PreResponse]
198
+ chat_sep = '\n'
199
+ humanstr = human # tag before human talks
200
+ botstr = bot # tag before bot talks
201
+ elif prompt_type in [PromptType.dai_faq.value, str(PromptType.dai_faq.value),
202
+ PromptType.dai_faq.name]:
203
+ promptA = ''
204
+ promptB = 'Answer the following Driverless AI question.\n'
205
+
206
+ PreInstruct = """
207
+ ### Driverless AI frequently asked question:
208
+ """
209
+
210
+ PreInput = None
211
+
212
+ PreResponse = """
213
+ ### Driverless AI documentation answer:
214
+ """
215
+ terminate_response = ['\n\n']
216
+ chat_sep = terminate_response
217
+ humanstr = PreInstruct
218
+ botstr = PreResponse
219
+ elif prompt_type in [PromptType.summarize.value, str(PromptType.summarize.value),
220
+ PromptType.summarize.name]:
221
+ promptA = promptB = PreInput = ''
222
+ PreInstruct = '## Main Text\n\n'
223
+ PreResponse = '\n\n## Summary\n\n'
224
+ terminate_response = None
225
+ chat_sep = '\n'
226
+ humanstr = PreInstruct
227
+ botstr = PreResponse
228
+ elif prompt_type in [PromptType.instruct_vicuna.value, str(PromptType.instruct_vicuna.value),
229
+ PromptType.instruct_vicuna.name]:
230
+ promptA = promptB = "A chat between a curious human and an artificial intelligence assistant. " \
231
+ "The assistant gives helpful, detailed, and polite answers to the human's questions." if not (
232
+ chat and reduced) else ''
233
+
234
+ PreInstruct = """
235
+ ### Human:
236
+ """
237
+
238
+ PreInput = None
239
+
240
+ PreResponse = """
241
+ ### Assistant:
242
+ """
243
+ terminate_response = [
244
+ '### Human:'] # but only allow terminate after prompt is found correctly, else can't terminate
245
+ chat_sep = '\n'
246
+ humanstr = PreInstruct
247
+ botstr = PreResponse
248
+ elif prompt_type in [PromptType.prompt_answer.value, str(PromptType.prompt_answer.value),
249
+ PromptType.prompt_answer.name]:
250
+ preprompt = ''
251
+ prompt_tokens = "<|prompt|>"
252
+ answer_tokens = "<|answer|>"
253
+ start = prompt_tokens
254
+ promptB = promptA = '%s%s' % (preprompt, start)
255
+ PreInstruct = ""
256
+ PreInput = None
257
+ PreResponse = answer_tokens
258
+ eos = '<|endoftext|>' # neox eos
259
+ terminate_response = [start, PreResponse, eos]
260
+ chat_sep = eos
261
+ humanstr = prompt_tokens
262
+ botstr = answer_tokens
263
+ elif prompt_type in [PromptType.open_assistant.value, str(PromptType.open_assistant.value),
264
+ PromptType.open_assistant.name]:
265
+ # From added_tokens.json
266
+ preprompt = ''
267
+ prompt_tokens = "<|prompter|>"
268
+ answer_tokens = "<|assistant|>"
269
+ start = prompt_tokens
270
+ promptB = promptA = '%s%s' % (preprompt, start)
271
+ PreInstruct = ""
272
+ PreInput = None
273
+ PreResponse = answer_tokens
274
+ pend = "<|prefix_end|>"
275
+ eos = "</s>"
276
+ terminate_response = [start, PreResponse, pend, eos]
277
+ chat_sep = eos
278
+ humanstr = prompt_tokens
279
+ botstr = answer_tokens
280
+ elif prompt_type in [PromptType.wizard_lm.value, str(PromptType.wizard_lm.value),
281
+ PromptType.wizard_lm.name]:
282
+ # https://github.com/ehartford/WizardLM/blob/main/src/train_freeform.py
283
+ preprompt = ''
284
+ start = ''
285
+ promptB = promptA = '%s%s' % (preprompt, start)
286
+ PreInstruct = ""
287
+ PreInput = None
288
+ PreResponse = "\n\n### Response\n"
289
+ eos = "</s>"
290
+ terminate_response = [PreResponse, eos]
291
+ chat_sep = eos
292
+ humanstr = promptA
293
+ botstr = PreResponse
294
+ elif prompt_type in [PromptType.wizard_mega.value, str(PromptType.wizard_mega.value),
295
+ PromptType.wizard_mega.name]:
296
+ preprompt = ''
297
+ start = ''
298
+ promptB = promptA = '%s%s' % (preprompt, start)
299
+ PreInstruct = """
300
+ ### Instruction:
301
+ """
302
+ PreInput = None
303
+ PreResponse = """
304
+ ### Assistant:
305
+ """
306
+ terminate_response = [PreResponse]
307
+ chat_sep = '\n'
308
+ humanstr = PreInstruct
309
+ botstr = PreResponse
310
+ elif prompt_type in [PromptType.instruct_vicuna2.value, str(PromptType.instruct_vicuna2.value),
311
+ PromptType.instruct_vicuna2.name]:
312
+ promptA = promptB = "" if not (
313
+ chat and reduced) else ''
314
+
315
+ PreInstruct = """
316
+ HUMAN:
317
+ """
318
+
319
+ PreInput = None
320
+
321
+ PreResponse = """
322
+ ASSISTANT:
323
+ """
324
+ terminate_response = [
325
+ 'HUMAN:'] # but only allow terminate after prompt is found correctly, else can't terminate
326
+ chat_sep = '\n'
327
+ humanstr = PreInstruct
328
+ botstr = PreResponse
329
+ elif prompt_type in [PromptType.instruct_vicuna3.value, str(PromptType.instruct_vicuna3.value),
330
+ PromptType.instruct_vicuna3.name]:
331
+ promptA = promptB = "" if not (
332
+ chat and reduced) else ''
333
+
334
+ PreInstruct = """
335
+ ### User:
336
+ """
337
+
338
+ PreInput = None
339
+
340
+ PreResponse = """
341
+ ### Assistant:
342
+ """
343
+ terminate_response = [
344
+ '### User:'] # but only allow terminate after prompt is found correctly, else can't terminate
345
+ chat_sep = '\n'
346
+ humanstr = PreInstruct
347
+ botstr = PreResponse
348
+ elif prompt_type in [PromptType.wizard2.value, str(PromptType.wizard2.value),
349
+ PromptType.wizard2.name]:
350
+ # https://huggingface.co/TheBloke/WizardLM-7B-uncensored-GGML
351
+ preprompt = """Below is an instruction that describes a task. Write a response that appropriately completes the request."""
352
+ start = ''
353
+ promptB = promptA = '%s%s' % (preprompt, start)
354
+ PreInstruct = """
355
+ ### Instruction:
356
+ """
357
+ PreInput = None
358
+ PreResponse = """
359
+ ### Response:
360
+ """
361
+ terminate_response = [PreResponse]
362
+ chat_sep = '\n'
363
+ humanstr = PreInstruct
364
+ botstr = PreResponse
365
+ elif prompt_type in [PromptType.wizard3.value, str(PromptType.wizard3.value),
366
+ PromptType.wizard3.name]:
367
+ # https://huggingface.co/TheBloke/wizardLM-13B-1.0-GGML
368
+ preprompt = """A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions."""
369
+ start = ''
370
+ promptB = promptA = '%s%s' % (preprompt, start)
371
+ PreInstruct = """USER: """
372
+ PreInput = None
373
+ PreResponse = """ASSISTANT: """
374
+ terminate_response = [PreResponse]
375
+ chat_sep = '\n'
376
+ humanstr = PreInstruct
377
+ botstr = PreResponse
378
+
379
+ elif prompt_type in [PromptType.instruct_simple.value, str(PromptType.instruct_simple.value),
380
+ PromptType.instruct_simple.name]:
381
+ promptA = '' if not (chat and reduced) else ''
382
+ promptB = '' if not (chat and reduced) else ''
383
+
384
+ PreInstruct = """
385
+ ### Instruction:
386
+ """
387
+
388
+ PreInput = """
389
+ ### Input:
390
+ """
391
+
392
+ PreResponse = """
393
+ ### Response:
394
+ """
395
+ terminate_response = None
396
+ chat_sep = '\n'
397
+ humanstr = PreInstruct
398
+ botstr = PreResponse
399
+ else:
400
+ raise RuntimeError("No such prompt_type=%s" % prompt_type)
401
+
402
+ if return_dict:
403
+ return dict(promptA=promptA, promptB=promptB, PreInstruct=PreInstruct, PreInput=PreInput,
404
+ PreResponse=PreResponse, terminate_response=terminate_response, chat_sep=chat_sep,
405
+ humanstr=humanstr, botstr=botstr), ''
406
+ else:
407
+ return promptA, promptB, PreInstruct, PreInput, PreResponse, terminate_response, chat_sep, humanstr, botstr
408
+
409
+
410
+ def generate_prompt(data_point, prompt_type, prompt_dict, chat, reduced):
411
+ context = data_point.get('context')
412
+ if context is None:
413
+ context = ''
414
+ instruction = data_point.get('instruction')
415
+ input = data_point.get('input')
416
+ output = data_point.get('output')
417
+ prompt_type = data_point.get('prompt_type', prompt_type)
418
+ prompt_dict = data_point.get('prompt_dict', prompt_dict)
419
+ assert prompt_type in prompt_types, "Bad prompt type: %s" % prompt_type
420
+ promptA, promptB, PreInstruct, PreInput, PreResponse, \
421
+ terminate_response, chat_sep, humanstr, botstr = get_prompt(prompt_type, prompt_dict, chat, context, reduced)
422
+
423
+ prompt = context if not reduced else ''
424
+
425
+ if input and promptA:
426
+ prompt += f"""{promptA}"""
427
+ elif promptB:
428
+ prompt += f"""{promptB}"""
429
+
430
+ if instruction and PreInstruct is not None and input and PreInput is not None:
431
+ prompt += f"""{PreInstruct}{instruction}{PreInput}{input}"""
432
+ prompt = inject_newline(prompt_type, prompt)
433
+ elif instruction and input and PreInstruct is None and PreInput is not None:
434
+ prompt += f"""{PreInput}{instruction}
435
+ {input}"""
436
+ prompt = inject_newline(prompt_type, prompt)
437
+ elif input and instruction and PreInput is None and PreInstruct is not None:
438
+ prompt += f"""{PreInstruct}{instruction}
439
+ {input}"""
440
+ prompt = inject_newline(prompt_type, prompt)
441
+ elif instruction and PreInstruct is not None:
442
+ prompt += f"""{PreInstruct}{instruction}"""
443
+ prompt = inject_newline(prompt_type, prompt)
444
+ elif input and PreInput is not None:
445
+ prompt += f"""{PreInput}{input}"""
446
+ prompt = inject_newline(prompt_type, prompt)
447
+ elif input and instruction and PreInput is not None:
448
+ prompt += f"""{PreInput}{instruction}{input}"""
449
+ prompt = inject_newline(prompt_type, prompt)
450
+ elif input and instruction and PreInstruct is not None:
451
+ prompt += f"""{PreInstruct}{instruction}{input}"""
452
+ prompt = inject_newline(prompt_type, prompt)
453
+ elif input and instruction:
454
+ # i.e. for simple_instruct
455
+ prompt += f"""{instruction}: {input}"""
456
+ prompt = inject_newline(prompt_type, prompt)
457
+ elif input:
458
+ prompt += f"""{input}"""
459
+ prompt = inject_newline(prompt_type, prompt)
460
+ elif instruction:
461
+ prompt += f"""{instruction}"""
462
+ prompt = inject_newline(prompt_type, prompt)
463
+
464
+ if PreResponse is not None:
465
+ prompt += f"""{PreResponse}"""
466
+ pre_response = PreResponse # Don't use strip
467
+ else:
468
+ pre_response = ''
469
+
470
+ if output:
471
+ prompt += f"""{output}"""
472
+
473
+ return prompt, pre_response, terminate_response, chat_sep
474
+
475
+
476
+ def inject_newline(prompt_type, prompt):
477
+ if prompt_type not in [-1, '-1', 'plain', 'simple_instruct']:
478
+ # only add new line if structured prompt, while 'plain' is just generation of next tokens from input
479
+ prompt += '\n'
480
+ return prompt
481
+
482
+
483
+ class Prompter(object):
484
+ def __init__(self, prompt_type, prompt_dict, debug=False, chat=False, stream_output=False, repeat_penalty=True,
485
+ allowed_repeat_line_length=10):
486
+ self.prompt_type = prompt_type
487
+ self.prompt_dict = prompt_dict
488
+ data_point = dict(instruction='', input='', output='')
489
+ _, self.pre_response, self.terminate_response, self.chat_sep = \
490
+ generate_prompt(data_point, self.prompt_type, self.prompt_dict, chat, False)
491
+ self.debug = debug
492
+ self.chat = chat
493
+ self.stream_output = stream_output
494
+ self.repeat_penalty = repeat_penalty
495
+ self.allowed_repeat_line_length = allowed_repeat_line_length
496
+ self.prompt = None
497
+ context = "" # not for chat context
498
+ reduced = False # not for chat context
499
+ self.promptA, self.promptB, self.PreInstruct, self.PreInput, self.PreResponse, \
500
+ self.terminate_response, self.chat_sep, self.humanstr, self.botstr = \
501
+ get_prompt(self.prompt_type, self.prompt_dict, chat, context, reduced)
502
+
503
+ def generate_prompt(self, data_point):
504
+ reduced = False
505
+ prompt, _, _, _ = generate_prompt(data_point, self.prompt_type, self.prompt_dict, self.chat, reduced)
506
+ if self.debug:
507
+ print("prompt: ", prompt, flush=True)
508
+ self.prompt = prompt
509
+ return prompt
510
+
511
+ def get_response(self, outputs, prompt=None, sanitize_bot_response=True):
512
+ if isinstance(outputs, str):
513
+ outputs = [outputs]
514
+ if self.debug:
515
+ print("output:\n", '\n\n'.join(outputs), flush=True)
516
+ if prompt is not None:
517
+ self.prompt = prompt
518
+
519
+ def clean_response(response):
520
+ meaningless_words = ['<pad>', '</s>', '<|endoftext|>']
521
+ for word in meaningless_words:
522
+ response = response.replace(word, "")
523
+ if sanitize_bot_response:
524
+ from better_profanity import profanity
525
+ response = profanity.censor(response)
526
+ response = response.strip("\n")
527
+ return response
528
+
529
+ def clean_repeats(response):
530
+ lines = response.split('\n')
531
+ new_lines = []
532
+ [new_lines.append(line) for line in lines if
533
+ line not in new_lines or len(line) < self.allowed_repeat_line_length]
534
+ if self.debug and len(lines) != len(new_lines):
535
+ print("cleaned repeats: %s %s" % (len(lines), len(new_lines)), flush=True)
536
+ response = '\n'.join(new_lines)
537
+ return response
538
+
539
+ multi_output = len(outputs) > 1
540
+
541
+ for oi, output in enumerate(outputs):
542
+ if self.prompt_type in [PromptType.plain.value, str(PromptType.plain.value), PromptType.plain.name]:
543
+ output = clean_response(output)
544
+ elif prompt is None:
545
+ # then use most basic parsing like pipeline
546
+ if self.botstr in output:
547
+ if self.humanstr:
548
+ output = clean_response(output.split(self.botstr)[1].strip().split(self.humanstr)[0].strip())
549
+ else:
550
+ # i.e. use after bot but only up to next bot
551
+ output = clean_response(output.split(self.botstr)[1].strip().split(self.botstr)[0].strip())
552
+ else:
553
+ # output = clean_response(output.strip())
554
+ # assume just not printed yet
555
+ output = ""
556
+ else:
557
+ # find first instance of pre_response
558
+ # prompt sometimes has odd characters, that mutate length,
559
+ # so can't go by length alone
560
+ if self.pre_response:
561
+ outputi = output.find(prompt)
562
+ if outputi >= 0:
563
+ output = output[outputi + len(prompt):]
564
+ allow_terminate = True
565
+ else:
566
+ # subtraction is risky due to space offsets sometimes, so only do if necessary
567
+ output = output[len(prompt) - len(self.pre_response):]
568
+ # [1] to avoid repeated pre_response, just take first (after prompt - pre_response for chat)
569
+ if self.pre_response in output:
570
+ output = output.split(self.pre_response)[1]
571
+ allow_terminate = True
572
+ else:
573
+ if output:
574
+ print("Failure of parsing or not enough output yet: %s" % output, flush=True)
575
+ allow_terminate = False
576
+ else:
577
+ allow_terminate = True
578
+ output = output[len(prompt):]
579
+ # clean after subtract prompt out, so correct removal of pre_response
580
+ output = clean_response(output).strip()
581
+ if self.repeat_penalty:
582
+ output = clean_repeats(output).strip()
583
+ if self.terminate_response and allow_terminate:
584
+ finds = []
585
+ for term in self.terminate_response:
586
+ finds.append(output.find(term))
587
+ finds = [x for x in finds if x >= 0]
588
+ if len(finds) > 0:
589
+ termi = finds[0]
590
+ output = output[:termi].strip()
591
+ else:
592
+ output = output.strip()
593
+ else:
594
+ output = output.strip()
595
+ if multi_output:
596
+ # prefix with output counter
597
+ output = "\n=========== Output %d\n\n" % (1 + oi) + output
598
+ if oi > 0:
599
+ # postfix outputs with separator
600
+ output += '\n'
601
+ outputs[oi] = output
602
+ # join all outputs, only one extra new line between outputs
603
+ output = '\n'.join(outputs)
604
+ if self.debug:
605
+ print("outputclean:\n", '\n\n'.join(outputs), flush=True)
606
+ return output
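
A round-trip sketch of the Prompter (the model output string is fabricated, purely to show how get_response strips the prompt and stops at the next human turn; better_profanity must be installed since sanitize_bot_response defaults to True):

prompter = Prompter('human_bot', None, chat=False, stream_output=False)
prompt = prompter.generate_prompt(dict(instruction='Why is the sky blue?',
                                       input='', output=''))
# prompt == "<human>: Why is the sky blue?\n<bot>:"

fake_model_output = prompt + " Because of Rayleigh scattering.<human>: thanks"
print(prompter.get_response(fake_model_output, prompt=prompt))
# -> "Because of Rayleigh scattering."
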
requirements.txt CHANGED
@@ -0,0 +1,152 @@
1
+ # for generate (gradio server) and finetune
2
+ datasets==2.12.0
3
+ sentencepiece==0.1.97
4
+ gradio==3.31.0
5
+ huggingface_hub==0.14.1
6
+ appdirs==1.4.4
7
+ fire==0.5.0
8
+ docutils==0.19
9
+ torch==2.0.1
10
+ evaluate==0.4.0
11
+ rouge_score==0.1.2
12
+ sacrebleu==2.3.1
13
+ scikit-learn==1.2.2
14
+ alt-profanity-check==1.2.2
15
+ better-profanity==0.6.1
16
+ numpy==1.24.2
17
+ pandas==2.0.0
18
+ matplotlib==3.7.1
19
+ loralib==0.1.1
20
+ bitsandbytes==0.39.0
21
+ accelerate==0.19.0
22
+ git+https://github.com/huggingface/peft.git@3714aa2fff158fdfa637b2b65952580801d890b2
23
+ transformers==4.28.1
24
+ tokenizers==0.13.3
25
+ APScheduler==3.10.1
26
+
27
+ # optional for generate
28
+ pynvml==11.5.0
29
+ psutil==5.9.4
30
+ boto3==1.26.101
31
+ botocore==1.29.101
32
+
33
+ # optional for finetune
34
+ tensorboard==2.12.1
35
+ neptune==1.1.1
36
+
37
+ # for gradio client
38
+ gradio_client==0.2.5
39
+ beautifulsoup4==4.12.2
40
+ markdown==3.4.1
41
+
42
+ # data and testing
43
+ pytest==7.2.2
44
+ pytest-xdist==3.2.1
45
+ nltk==3.8.1
46
+ textstat==0.7.3
47
+ pandoc==2.3
48
+ #pypandoc==1.11
49
+ pypandoc_binary==1.11
50
+ openpyxl==3.1.2
51
+ lm_dataformat==0.0.20
52
+ bioc==2.0
53
+
54
+ # falcon
55
+ einops==0.6.1
56
+ instructorembedding==1.0.1
57
+
58
+ # for gpt4all .env file, but avoid worrying about imports
59
+ python-dotenv==1.0.0
+ # optional for chat with PDF
60
+ langchain==0.0.193
61
+ pypdf==3.8.1
62
+ tiktoken==0.3.3
63
+ # avoid textract, requires old six
64
+ #textract==1.6.5
65
+
66
+ # for HF embeddings
67
+ sentence_transformers==2.2.2
68
+ # for OpenAI embeddings (requires key)
69
+ openai==0.27.6
70
+
71
+ # local vector db
72
+ chromadb==0.3.25
73
+ # server vector db
74
+ #pymilvus==2.2.8
75
+
76
+ # weak url support, if can't install opencv etc. If comment-in this one, then comment-out unstructured[local-inference]==0.6.6
77
+ # unstructured==0.6.6
78
+
79
+ # strong support for images
80
+ # Requires on Ubuntu: sudo apt-get install libmagic-dev poppler-utils tesseract-ocr libreoffice
81
+ unstructured[local-inference]==0.6.6
82
+ #pdf2image==1.16.3
83
+ #pytesseract==0.3.10
84
+ pillow
85
+
86
+ pdfminer.six==20221105
87
+ urllib3==1.26.6
88
+ requests_file==1.5.1
89
+
90
+ #pdf2image==1.16.3
91
+ #pytesseract==0.3.10
92
+ tabulate==0.9.0
93
+ # FYI pandoc already part of requirements.txt
94
+
95
+ # JSONLoader, but makes some trouble for some users
96
+ # jq==1.4.1
97
+
98
+ # to check licenses
99
+ # Run: pip-licenses|grep -v 'BSD\|Apache\|MIT'
100
+ pip-licenses==4.3.0
101
+
102
+ # weaviate vector db
103
+ weaviate-client==3.19.2
+ # optional for chat with PDF
104
+ langchain==0.0.193
105
+ pypdf==3.8.1
106
+ tiktoken==0.3.3
107
+ # avoid textract, requires old six
108
+ #textract==1.6.5
109
+
110
+ # for HF embeddings
111
+ sentence_transformers==2.2.2
112
+ # for OpenAI embeddings (requires key)
113
+ openai==0.27.6
114
+
115
+ # local vector db
116
+ chromadb==0.3.25
117
+ # server vector db
118
+ #pymilvus==2.2.8
119
+
120
+ # weak url support, if can't install opencv etc. If comment-in this one, then comment-out unstructured[local-inference]==0.6.6
121
+ # unstructured==0.6.6
122
+
123
+ # strong support for images
124
+ # Requires on Ubuntu: sudo apt-get install libmagic-dev poppler-utils tesseract-ocr libreoffice
125
+ unstructured[local-inference]==0.6.6
126
+ #pdf2image==1.16.3
127
+ #pytesseract==0.3.10
128
+ pillow
129
+
130
+ pdfminer.six==20221105
131
+ urllib3==1.26.6
132
+ requests_file==1.5.1
133
+
134
+ #pdf2image==1.16.3
135
+ #pytesseract==0.3.10
136
+ tabulate==0.9.0
137
+ # FYI pandoc already part of requirements.txt
138
+
139
+ # JSONLoader, but makes some trouble for some users
140
+ # jq==1.4.1
141
+
142
+ # to check licenses
143
+ # Run: pip-licenses|grep -v 'BSD\|Apache\|MIT'
144
+ pip-licenses==4.3.0
145
+
146
+ # weaviate vector db
147
+ weaviate-client==3.19.2
+ faiss-gpu==1.7.2
148
+ gpt4all==0.2.3
149
+ llama-cpp-python==0.1.55
150
+ arxiv==1.4.7
151
+ pymupdf==1.22.3 # AGPL license
152
+ # extract-msg==0.41.1 # GPL3
stopping.py DELETED
@@ -1 +0,0 @@
1
- ../../stopping.py
 
 
stopping.py ADDED
@@ -0,0 +1,73 @@
1
+ import torch
2
+ from transformers import StoppingCriteria, StoppingCriteriaList
3
+
4
+ from prompter import PromptType
5
+
6
+
7
+ class StoppingCriteriaSub(StoppingCriteria):
8
+
9
+ def __init__(self, stops=[], encounters=[], device="cuda"):
10
+ super().__init__()
11
+ assert len(stops) % len(encounters) == 0, "Number of stops must be a multiple of number of encounters"
12
+ self.encounters = encounters
13
+ self.stops = [stop.to(device) for stop in stops]
14
+ self.num_stops = [0] * len(stops)
15
+
16
+ def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
17
+ for stopi, stop in enumerate(self.stops):
18
+ if torch.all((stop == input_ids[0][-len(stop):])).item():
19
+ self.num_stops[stopi] += 1
20
+ if self.num_stops[stopi] >= self.encounters[stopi % len(self.encounters)]:
21
+ # print("Stopped", flush=True)
22
+ return True
23
+ # print("Tokens: %s" % input_ids[0].cpu().numpy(), flush=True)
24
+ # print("Stop Tokens: %s" % [x.cpu().numpy() for x in self.stops], flush=True)
25
+ return False
26
+
27
+
28
+ def get_stopping(prompt_type, prompt_dict, tokenizer, device, human='<human>:', bot="<bot>:"):
29
+ # FIXME: prompt_dict unused currently
30
+ if prompt_type in [PromptType.human_bot.name, PromptType.instruct_vicuna.name, PromptType.instruct_with_end.name]:
31
+ if prompt_type == PromptType.human_bot.name:
32
+ # encounters = [prompt.count(human) + 1, prompt.count(bot) + 1]
33
+ # stopping only starts once output is beyond prompt
34
+ # 1 human is enough to trigger, but need 2 bots, because very first view back will be bot we added
35
+ stop_words = [human, bot, '\n' + human, '\n' + bot]
36
+ encounters = [1, 2]
37
+ elif prompt_type == PromptType.instruct_vicuna.name:
38
+ # even below is not enough, generic strings and many ways to encode
39
+ stop_words = [
40
+ '### Human:',
41
+ """
42
+ ### Human:""",
43
+ """
44
+ ### Human:
45
+ """,
46
+ '### Assistant:',
47
+ """
48
+ ### Assistant:""",
49
+ """
50
+ ### Assistant:
51
+ """,
52
+ ]
53
+ encounters = [1, 2]
54
+ else:
55
+ # some instruct prompts have this as end, doesn't hurt to stop on it since not common otherwise
56
+ stop_words = ['### End']
57
+ encounters = [1]
58
+ stop_words_ids = [
59
+ tokenizer(stop_word, return_tensors='pt')['input_ids'].squeeze() for stop_word in stop_words]
60
+ # handle single token case
61
+ stop_words_ids = [x if len(x.shape) > 0 else torch.tensor([x]) for x in stop_words_ids]
62
+ stop_words_ids = [x for x in stop_words_ids if x.shape[0] > 0]
63
+ # avoid padding in front of tokens
64
+ if tokenizer._pad_token: # use hidden variable to avoid annoying property logger warning
65
+ stop_words_ids = [x[1:] if x[0] == tokenizer.pad_token_id and len(x) > 1 else x for x in stop_words_ids]
66
+ # handle fake \n added
67
+ stop_words_ids = [x[1:] if y[0] == '\n' else x for x, y in zip(stop_words_ids, stop_words)]
68
+ # build stopper
69
+ stopping_criteria = StoppingCriteriaList(
70
+ [StoppingCriteriaSub(stops=stop_words_ids, encounters=encounters, device=device)])
71
+ else:
72
+ stopping_criteria = StoppingCriteriaList()
73
+ return stopping_criteria
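For reference, a usage sketch of how this stopping machinery would plug into transformers generation; the checkpoint id below is only an example, and the prompt string assumes the human_bot format that get_stopping targets.

# Usage sketch (example checkpoint; any causal LM with the human_bot prompt format would do).
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from prompter import PromptType
from stopping import get_stopping

model_name = "h2oai/h2ogpt-oig-oasst1-512-6_9b"  # placeholder example
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.float16, device_map="auto")

prompt = "<human>: What rights does the Apache 2.0 license grant?\n<bot>:"
inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
stopping_criteria = get_stopping(PromptType.human_bot.name, None, tokenizer, model.device)
out = model.generate(**inputs, max_new_tokens=128, stopping_criteria=stopping_criteria)
print(tokenizer.decode(out[0], skip_special_tokens=True))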
utils.py DELETED
@@ -1 +0,0 @@
1
- ../../utils.py
 
 
utils.py ADDED
@@ -0,0 +1,856 @@
1
+ import contextlib
2
+ import functools
3
+ import hashlib
4
+ import inspect
5
+ import os
6
+ import gc
7
+ import pathlib
8
+ import random
9
+ import shutil
10
+ import subprocess
11
+ import sys
12
+ import threading
13
+ import time
14
+ import traceback
15
+ import zipfile
16
+ from datetime import datetime
17
+ from enum import Enum
18
+
19
+ import filelock
20
+ import requests, uuid
21
+ from typing import Tuple, Callable, Dict
22
+ from tqdm.auto import tqdm
23
+ from joblib import Parallel
24
+ from concurrent.futures import ProcessPoolExecutor
25
+ import numpy as np
26
+ import pandas as pd
27
+
28
+
29
+ def set_seed(seed: int):
30
+ """
31
+ Sets the seed of the entire notebook so results are the same every time we run.
32
+ This is for REPRODUCIBILITY.
33
+ """
34
+ import torch
35
+ np.random.seed(seed)
36
+ random_state = np.random.RandomState(seed)
37
+ random.seed(seed)
38
+ torch.manual_seed(seed)
39
+ torch.cuda.manual_seed(seed)
40
+ torch.backends.cudnn.deterministic = True
41
+ torch.backends.cudnn.benchmark = False
42
+ os.environ['PYTHONHASHSEED'] = str(seed)
43
+ return random_state
44
+
45
+
46
+ def flatten_list(lis):
47
+ """Given a list, possibly nested to any level, return it flattened."""
48
+ new_lis = []
49
+ for item in lis:
50
+ if type(item) == type([]):
51
+ new_lis.extend(flatten_list(item))
52
+ else:
53
+ new_lis.append(item)
54
+ return new_lis
55
+
56
+
57
+ def clear_torch_cache():
58
+ import torch
59
+ if torch.cuda.is_available():
60
+ torch.cuda.empty_cache()
61
+ torch.cuda.ipc_collect()
62
+ gc.collect()
63
+
64
+
65
+ def ping():
66
+ try:
67
+ print('Ping: %s' % str(datetime.now()), flush=True)
68
+ except AttributeError:
69
+ # some programs wrap print and will fail with flush passed
70
+ pass
71
+
72
+
73
+ def get_torch_allocated():
74
+ import torch
75
+ return torch.cuda.memory_allocated()
76
+
77
+
78
+ def get_device():
79
+ import torch
80
+ if torch.cuda.is_available():
81
+ device = "cuda"
82
+ else:
83
+ device = "cpu"
84
+
85
+ return device
86
+
87
+
88
+ def system_info():
89
+ import psutil
90
+
91
+ system = {}
92
+ # https://stackoverflow.com/questions/48951136/plot-multiple-graphs-in-one-plot-using-tensorboard
93
+ # https://arshren.medium.com/monitoring-your-devices-in-python-5191d672f749
94
+ temps = psutil.sensors_temperatures(fahrenheit=False)
95
+ if 'coretemp' in temps:
96
+ coretemp = temps['coretemp']
97
+ temp_dict = {k.label: k.current for k in coretemp}
98
+ for k, v in temp_dict.items():
99
+ system['CPU_C/%s' % k] = v
100
+
101
+ # https://github.com/gpuopenanalytics/pynvml/blob/master/help_query_gpu.txt
102
+ from pynvml.smi import nvidia_smi
103
+ nvsmi = nvidia_smi.getInstance()
104
+
105
+ gpu_power_dict = {'W_gpu%d' % i: x['power_readings']['power_draw'] for i, x in
106
+ enumerate(nvsmi.DeviceQuery('power.draw')['gpu'])}
107
+ for k, v in gpu_power_dict.items():
108
+ system['GPU_W/%s' % k] = v
109
+
110
+ gpu_temp_dict = {'C_gpu%d' % i: x['temperature']['gpu_temp'] for i, x in
111
+ enumerate(nvsmi.DeviceQuery('temperature.gpu')['gpu'])}
112
+ for k, v in gpu_temp_dict.items():
113
+ system['GPU_C/%s' % k] = v
114
+
115
+ gpu_memory_free_dict = {'MiB_gpu%d' % i: x['fb_memory_usage']['free'] for i, x in
116
+ enumerate(nvsmi.DeviceQuery('memory.free')['gpu'])}
117
+ gpu_memory_total_dict = {'MiB_gpu%d' % i: x['fb_memory_usage']['total'] for i, x in
118
+ enumerate(nvsmi.DeviceQuery('memory.total')['gpu'])}
119
+ gpu_memory_frac_dict = {k: gpu_memory_free_dict[k] / gpu_memory_total_dict[k] for k in gpu_memory_total_dict}
120
+ for k, v in gpu_memory_frac_dict.items():
121
+ system['GPU_M/%s' % k] = v
122
+
123
+ system['hash'] = get_githash()
124
+
125
+ return system
126
+
127
+
128
+ def system_info_print():
129
+ try:
130
+ df = pd.DataFrame.from_dict(system_info(), orient='index')
131
+ # avoid slamming GPUs
132
+ time.sleep(1)
133
+ return df.to_markdown()
134
+ except Exception as e:
135
+ return "Error: %s" % str(e)
136
+
137
+
138
+ def zip_data(root_dirs=None, zip_file=None, base_dir='./', fail_any_exception=False):
139
+ try:
140
+ return _zip_data(zip_file=zip_file, base_dir=base_dir, root_dirs=root_dirs)
141
+ except Exception as e:
142
+ traceback.print_exc()
143
+ print('Exception in zipping: %s' % str(e))
144
+ if not fail_any_exception:
145
+ raise
146
+
147
+
148
+ def _zip_data(root_dirs=None, zip_file=None, base_dir='./'):
149
+ if isinstance(root_dirs, str):
150
+ root_dirs = [root_dirs]
151
+ if zip_file is None:
152
+ datetime_str = str(datetime.now()).replace(" ", "_").replace(":", "_")
153
+ host_name = os.getenv('HF_HOSTNAME', 'emptyhost')
154
+ zip_file = "data_%s_%s.zip" % (datetime_str, host_name)
155
+ assert root_dirs is not None
156
+ if not os.path.isdir(os.path.dirname(zip_file)) and os.path.dirname(zip_file):
157
+ os.makedirs(os.path.dirname(zip_file), exist_ok=True)
158
+ with zipfile.ZipFile(zip_file, "w") as expt_zip:
159
+ for root_dir in root_dirs:
160
+ if root_dir is None:
161
+ continue
162
+ for root, d, files in os.walk(root_dir):
163
+ for file in files:
164
+ file_to_archive = os.path.join(root, file)
165
+ assert os.path.exists(file_to_archive)
166
+ path_to_archive = os.path.relpath(file_to_archive, base_dir)
167
+ expt_zip.write(filename=file_to_archive, arcname=path_to_archive)
168
+ return zip_file, zip_file
169
+
170
+
171
+ def save_generate_output(output=None, base_model=None, save_dir=None):
172
+ try:
173
+ return _save_generate_output(output=output, base_model=base_model, save_dir=save_dir)
174
+ except Exception as e:
175
+ traceback.print_exc()
176
+ print('Exception in saving: %s' % str(e))
177
+
178
+
179
+ def _save_generate_output(output=None, base_model=None, save_dir=None):
180
+ """
181
+ Save conversation to .json, row by row.
182
+ json_file_path is path to final JSON file. If not in ., then will attempt to make directories.
183
+ Appends if file exists
184
+ """
185
+ assert save_dir, "save_dir must be provided"
186
+ if os.path.exists(save_dir) and not os.path.isdir(save_dir):
187
+ raise RuntimeError("save_dir already exists and is not a directory!")
188
+ os.makedirs(save_dir, exist_ok=True)
189
+ import json
190
+ if output[-10:] == '\n\n<human>:':
191
+ # remove trailing <human>:
192
+ output = output[:-10]
193
+ with filelock.FileLock("save_dir.lock"):
194
+ # lock logging in case have concurrency
195
+ with open(os.path.join(save_dir, "history.json"), "a") as f:
196
+ # just add [ at start, and ] at end, and have proper JSON dataset
197
+ f.write(
198
+ " " + json.dumps(
199
+ dict(text=output, time=time.ctime(), base_model=base_model)
200
+ ) + ",\n"
201
+ )
202
+
203
+
204
+ def s3up(filename):
205
+ try:
206
+ return _s3up(filename)
207
+ except Exception as e:
208
+ traceback.print_exc()
209
+ print('Exception for file %s in s3up: %s' % (filename, str(e)))
210
+ return "Failed to upload %s: Error: %s" % (filename, str(e))
211
+
212
+
213
+ def _s3up(filename):
214
+ import boto3
215
+
216
+ aws_access_key_id = os.getenv('AWS_SERVER_PUBLIC_KEY')
217
+ aws_secret_access_key = os.getenv('AWS_SERVER_SECRET_KEY')
218
+ bucket = os.getenv('AWS_BUCKET')
219
+ assert aws_access_key_id, "Set AWS key"
220
+ assert aws_secret_access_key, "Set AWS secret"
221
+ assert bucket, "Set AWS Bucket"
222
+
223
+ s3 = boto3.client('s3',
224
+ aws_access_key_id=os.getenv('AWS_SERVER_PUBLIC_KEY'),
225
+ aws_secret_access_key=os.getenv('AWS_SERVER_SECRET_KEY'),
226
+ )
227
+ ret = s3.upload_file(
228
+ Filename=filename,
229
+ Bucket=os.getenv('AWS_BUCKET'),
230
+ Key=filename,
231
+ )
232
+ if ret in [None, '']:
233
+ return "Successfully uploaded %s" % filename
234
+
235
+
236
+ def get_githash():
237
+ try:
238
+ githash = subprocess.run(['git', 'rev-parse', 'HEAD'], stdout=subprocess.PIPE).stdout.decode('utf-8')[0:-1]
239
+ except:
240
+ githash = ''
241
+ return githash
242
+
243
+
244
+ def copy_code(run_id):
245
+ """
246
+ copy code to track changes
247
+ :param run_id:
248
+ :return:
249
+ """
250
+ rnd_num = str(random.randint(0, 2 ** 31))
251
+ run_id = 'run_' + str(run_id)
252
+ os.makedirs(run_id, exist_ok=True)
253
+ me_full = os.path.join(pathlib.Path(__file__).parent.resolve(), __file__)
254
+ me_file = os.path.basename(__file__)
255
+ new_me = os.path.join(run_id, me_file + '_' + get_githash())
256
+ if os.path.isfile(new_me):
257
+ new_me = os.path.join(run_id, me_file + '_' + get_githash() + '_' + rnd_num)
258
+ shutil.copy(me_full, new_me)
259
+ else:
260
+ shutil.copy(me_full, new_me)
261
+
262
+
263
+ class NullContext(threading.local):
264
+ """No-op context manager, executes block without doing any additional processing.
265
+
266
+ Used as a stand-in if a particular block of code is only sometimes
267
+ used with a normal context manager:
268
+ """
269
+
270
+ def __init__(self, *args, **kwargs):
271
+ pass
272
+
273
+ def __enter__(self):
274
+ return self
275
+
276
+ def __exit__(self, exc_type, exc_value, exc_traceback):
277
+ self.finally_act()
278
+
279
+ def finally_act(self):
280
+ pass
281
+
282
+
283
+ def wrapped_partial(func, *args, **kwargs):
284
+ """
285
+ Give partial properties of normal function, like __name__ attribute etc.
286
+ :param func:
287
+ :param args:
288
+ :param kwargs:
289
+ :return:
290
+ """
291
+ partial_func = functools.partial(func, *args, **kwargs)
292
+ functools.update_wrapper(partial_func, func)
293
+ return partial_func
294
+
295
+
296
+ class ThreadException(Exception):
297
+ pass
298
+
299
+
300
+ class EThread(threading.Thread):
301
+ # Function that raises the custom exception
302
+ def __init__(self, group=None, target=None, name=None,
303
+ args=(), kwargs=None, *, daemon=None, streamer=None, bucket=None):
304
+ self.bucket = bucket
305
+ self.streamer = streamer
306
+ self.exc = None
307
+ self._return = None
308
+ super().__init__(group=group, target=target, name=name, args=args, kwargs=kwargs, daemon=daemon)
309
+
310
+ def run(self):
311
+ # Variable that stores the exception, if raised by someFunction
312
+ try:
313
+ if self._target is not None:
314
+ self._return = self._target(*self._args, **self._kwargs)
315
+ except BaseException as e:
316
+ print("thread exception: %s" % str(sys.exc_info()))
317
+ self.bucket.put(sys.exc_info())
318
+ self.exc = e
319
+ if self.streamer:
320
+ print("make stop: %s" % str(sys.exc_info()), flush=True)
321
+ self.streamer.do_stop = True
322
+ finally:
323
+ # Avoid a refcycle if the thread is running a function with
324
+ # an argument that has a member that points to the thread.
325
+ del self._target, self._args, self._kwargs
326
+
327
+ def join(self, timeout=None):
328
+ threading.Thread.join(self)
329
+ # Since join() returns in caller thread
330
+ # we re-raise the caught exception
331
+ # if any was caught
332
+ if self.exc:
333
+ raise self.exc
334
+ return self._return
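A usage sketch for EThread; the worker function and queue are hypothetical. The bucket queue collects sys.exc_info() from the thread, and join() re-raises any caught exception in the caller.

# Hypothetical worker; EThread re-raises the worker's exception when join() is called.
import queue
from utils import EThread

def worker(x):
    if x < 0:
        raise ValueError("negative input")
    return x * 2

bucket = queue.Queue()  # run() puts sys.exc_info() here if the target raises
t = EThread(target=worker, args=(21,), bucket=bucket, streamer=None)
t.start()
print(t.join())  # 42; join() would re-raise ValueError had the input been negative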
335
+
336
+
337
+ def import_matplotlib():
338
+ import matplotlib
339
+ matplotlib.use('agg')
340
+ # KEEP THESE HERE! START
341
+ import matplotlib.pyplot as plt
342
+ import pandas as pd
343
+ # to avoid dlopen deadlock in fork
344
+ import pandas.core.computation.expressions as pd_expressions
345
+ import pandas._libs.groupby as pd_libgroupby
346
+ import pandas._libs.reduction as pd_libreduction
347
+ import pandas.core.algorithms as pd_algorithms
348
+ import pandas.core.common as pd_com
349
+ import numpy as np
350
+ # KEEP THESE HERE! END
351
+
352
+
353
+ def get_sha(value):
354
+ return hashlib.md5(str(value).encode('utf-8')).hexdigest()
355
+
356
+
357
+ def sanitize_filename(name):
358
+ """
359
+ Sanitize file *base* names.
360
+ :param name: name to sanitize
361
+ :return:
362
+ """
363
+ bad_chars = ['[', ']', ',', '/', '\\', '\\w', '\\s', '-', '+', '\"', '\'', '>', '<', ' ', '=', ')', '(', ':', '^']
364
+ for char in bad_chars:
365
+ name = name.replace(char, "_")
366
+
367
+ length = len(name)
368
+ file_length_limit = 250 # bit smaller than 256 for safety
369
+ sha_length = 32
370
+ real_length_limit = file_length_limit - (sha_length + 2)
371
+ if length > file_length_limit:
372
+ sha = get_sha(name)
373
+ half_real_length_limit = max(1, int(real_length_limit / 2))
374
+ name = name[0:half_real_length_limit] + "_" + sha + "_" + name[length - half_real_length_limit:length]
375
+
376
+ return name
377
+
378
+
379
+ def shutil_rmtree(*args, **kwargs):
380
+ return shutil.rmtree(*args, **kwargs)
381
+
382
+
383
+ def remove(path: str):
384
+ try:
385
+ if path is not None and os.path.exists(path):
386
+ if os.path.isdir(path):
387
+ shutil_rmtree(path, ignore_errors=True)
388
+ else:
389
+ with contextlib.suppress(FileNotFoundError):
390
+ os.remove(path)
391
+ except:
392
+ pass
393
+
394
+
395
+ def makedirs(path, exist_ok=True):
396
+ """
397
+ Avoid some inefficiency in os.makedirs()
398
+ :param path:
399
+ :param exist_ok:
400
+ :return:
401
+ """
402
+ if os.path.isdir(path) and os.path.exists(path):
403
+ assert exist_ok, "Path already exists"
404
+ return path
405
+ os.makedirs(path, exist_ok=exist_ok)
+ return path
406
+
407
+
408
+ def atomic_move_simple(src, dst):
409
+ try:
410
+ shutil.move(src, dst)
411
+ except (shutil.Error, FileExistsError):
412
+ pass
413
+ remove(src)
414
+
415
+
416
+ def download_simple(url, dest=None, print_func=None):
417
+ if print_func is not None:
418
+ print_func("BEGIN get url %s" % str(url))
419
+ if url.startswith("file://"):
420
+ from requests_file import FileAdapter
421
+ s = requests.Session()
422
+ s.mount('file://', FileAdapter())
423
+ url_data = s.get(url, stream=True)
424
+ else:
425
+ url_data = requests.get(url, stream=True)
426
+ if dest is None:
427
+ dest = os.path.basename(url)
428
+ if url_data.status_code != requests.codes.ok:
429
+ msg = "Cannot get url %s, code: %s, reason: %s" % (
430
+ str(url),
431
+ str(url_data.status_code),
432
+ str(url_data.reason),
433
+ )
434
+ raise requests.exceptions.RequestException(msg)
435
+ url_data.raw.decode_content = True
436
+ makedirs(os.path.dirname(dest), exist_ok=True)
437
+ uuid_tmp = str(uuid.uuid4())[:6]
438
+ dest_tmp = dest + "_dl_" + uuid_tmp + ".tmp"
439
+ with open(dest_tmp, "wb") as f:
440
+ shutil.copyfileobj(url_data.raw, f)
441
+ atomic_move_simple(dest_tmp, dest)
442
+ if print_func is not None:
443
+ print_func("END get url %s" % str(url))
444
+
445
+
446
+ def download(url, dest=None, dest_path=None):
447
+ if dest_path is not None:
448
+ dest = os.path.join(dest_path, os.path.basename(url))
449
+ if os.path.isfile(dest):
450
+ print("already downloaded %s -> %s" % (url, dest))
451
+ return dest
452
+ elif dest is not None:
453
+ if os.path.exists(dest):
454
+ print("already downloaded %s -> %s" % (url, dest))
455
+ return dest
456
+ else:
457
+ uuid_tmp = "dl2_" + str(uuid.uuid4())[:6]
458
+ dest = uuid_tmp + os.path.basename(url)
459
+
460
+ print("downloading %s to %s" % (url, dest))
461
+
462
+ if url.startswith("file://"):
463
+ from requests_file import FileAdapter
464
+ s = requests.Session()
465
+ s.mount('file://', FileAdapter())
466
+ url_data = s.get(url, stream=True)
467
+ else:
468
+ url_data = requests.get(url, stream=True)
469
+
470
+ if url_data.status_code != requests.codes.ok:
471
+ msg = "Cannot get url %s, code: %s, reason: %s" % (
472
+ str(url), str(url_data.status_code), str(url_data.reason))
473
+ raise requests.exceptions.RequestException(msg)
474
+ url_data.raw.decode_content = True
475
+ dirname = os.path.dirname(dest)
476
+ if dirname != "" and not os.path.isdir(dirname):
477
+ makedirs(os.path.dirname(dest), exist_ok=True)
478
+ uuid_tmp = "dl3_" + str(uuid.uuid4())[:6]
479
+ dest_tmp = dest + "_" + uuid_tmp + ".tmp"
480
+ with open(dest_tmp, 'wb') as f:
481
+ shutil.copyfileobj(url_data.raw, f)
482
+ try:
483
+ shutil.move(dest_tmp, dest)
484
+ except FileExistsError:
485
+ pass
486
+ remove(dest_tmp)
487
+ return dest
488
+
489
+
490
+ def get_url(x, from_str=False, short_name=False):
491
+ if not from_str:
492
+ source = x.metadata['source']
493
+ else:
494
+ source = x
495
+ if short_name:
496
+ source_name = get_short_name(source)
497
+ else:
498
+ source_name = source
499
+ if source.startswith('http://') or source.startswith('https://'):
500
+ return """<a href="%s" target="_blank" rel="noopener noreferrer">%s</a>""" % (
501
+ source, source_name)
502
+ else:
503
+ return """<a href="file/%s" target="_blank" rel="noopener noreferrer">%s</a>""" % (
504
+ source, source_name)
505
+
506
+
507
+ def get_short_name(name, maxl=50):
508
+ if name is None:
509
+ return ''
510
+ length = len(name)
511
+ if length > maxl:
512
+ allow_length = maxl - 3
513
+ half_allowed = max(1, int(allow_length / 2))
514
+ name = name[0:half_allowed] + "..." + name[length - half_allowed:length]
515
+ return name
516
+
517
+
518
+ def cuda_vis_check(total_gpus):
519
+ """Helper function to count GPUs by environment variable
520
+ Stolen from Jon's h2o4gpu utils
521
+ """
522
+ cudavis = os.getenv("CUDA_VISIBLE_DEVICES")
523
+ which_gpus = []
524
+ if cudavis is not None:
525
+ # prune away white-space, non-numerics,
526
+ # except commas for simple checking
527
+ cudavis = "".join(cudavis.split())
528
+ import re
529
+ cudavis = re.sub("[^0-9,]", "", cudavis)
530
+
531
+ lencudavis = len(cudavis)
532
+ if lencudavis == 0:
533
+ total_gpus = 0
534
+ else:
535
+ total_gpus = min(
536
+ total_gpus,
537
+ os.getenv("CUDA_VISIBLE_DEVICES").count(",") + 1)
538
+ which_gpus = os.getenv("CUDA_VISIBLE_DEVICES").split(",")
539
+ which_gpus = [int(x) for x in which_gpus]
540
+ else:
541
+ which_gpus = list(range(0, total_gpus))
542
+
543
+ return total_gpus, which_gpus
544
+
545
+
546
+ def get_ngpus_vis(raise_if_exception=True):
547
+ ngpus_vis1 = 0
548
+
549
+ shell = False
550
+ if shell:
551
+ cmd = "nvidia-smi -L 2> /dev/null"
552
+ else:
553
+ cmd = ["nvidia-smi", "-L"]
554
+
555
+ try:
556
+ timeout = 5 * 3
557
+ o = subprocess.check_output(cmd, shell=shell, timeout=timeout)
558
+ lines = o.decode("utf-8").splitlines()
559
+ ngpus_vis1 = 0
560
+ for line in lines:
561
+ if 'Failed to initialize NVML' not in line:
562
+ ngpus_vis1 += 1
563
+ except (FileNotFoundError, subprocess.CalledProcessError, OSError):
564
+ # GPU systems might not have nvidia-smi, so can't fail
565
+ pass
566
+ except subprocess.TimeoutExpired as e:
567
+ print('Failed get_ngpus_vis: %s' % str(e))
568
+ if raise_if_exception:
569
+ raise
570
+
571
+ ngpus_vis1, which_gpus = cuda_vis_check(ngpus_vis1)
572
+ return ngpus_vis1
573
+
574
+
575
+ def get_mem_gpus(raise_if_exception=True, ngpus=None):
576
+ totalmem_gpus1 = 0
577
+ usedmem_gpus1 = 0
578
+ freemem_gpus1 = 0
579
+
580
+ if ngpus == 0:
581
+ return totalmem_gpus1, usedmem_gpus1, freemem_gpus1
582
+
583
+ try:
584
+ cmd = "nvidia-smi -q 2> /dev/null | grep -A 3 'FB Memory Usage'"
585
+ o = subprocess.check_output(cmd, shell=True, timeout=15)
586
+ lines = o.decode("utf-8").splitlines()
587
+ for line in lines:
588
+ if 'Total' in line:
589
+ totalmem_gpus1 += int(line.split()[2]) * 1024 ** 2
590
+ if 'Used' in line:
591
+ usedmem_gpus1 += int(line.split()[2]) * 1024 ** 2
592
+ if 'Free' in line:
593
+ freemem_gpus1 += int(line.split()[2]) * 1024 ** 2
594
+ except (FileNotFoundError, subprocess.CalledProcessError, OSError):
595
+ # GPU systems might not have nvidia-smi, so can't fail
596
+ pass
597
+ except subprocess.TimeoutExpired as e:
598
+ print('Failed get_mem_gpus: %s' % str(e))
599
+ if raise_if_exception:
600
+ raise
601
+
602
+ return totalmem_gpus1, usedmem_gpus1, freemem_gpus1
603
+
604
+
605
+ class ForkContext(threading.local):
606
+ """
607
+ Set context for forking
608
+ Ensures state is returned once done
609
+ """
610
+
611
+ def __init__(self, args=None, kwargs=None, forkdata_capable=True):
612
+ """
613
+ :param args:
614
+ :param kwargs:
615
+ :param forkdata_capable: whether fork is forkdata capable and will use copy-on-write forking of args/kwargs
616
+ """
617
+ self.forkdata_capable = forkdata_capable
618
+ if self.forkdata_capable:
619
+ self.has_args = args is not None
620
+ self.has_kwargs = kwargs is not None
621
+ forkdatacontext.args = args
622
+ forkdatacontext.kwargs = kwargs
623
+ else:
624
+ self.has_args = False
625
+ self.has_kwargs = False
626
+
627
+ def __enter__(self):
628
+ try:
629
+ # flush all outputs so doesn't happen during fork -- don't print/log inside ForkContext contexts!
630
+ sys.stdout.flush()
631
+ sys.stderr.flush()
632
+ except BaseException as e:
633
+ # exit not called if exception, and don't want to leave forkdatacontext filled in that case
634
+ print("ForkContext failure on enter: %s" % str(e))
635
+ self.finally_act()
636
+ raise
637
+ return self
638
+
639
+ def __exit__(self, exc_type, exc_value, exc_traceback):
640
+ self.finally_act()
641
+
642
+ def finally_act(self):
643
+ """
644
+ Done when exception hit or exit is reached in context
645
+ reset forkdatacontext first; it is crucial that it gets reset even if later calls fail
646
+ :return: None
647
+ """
648
+ if self.forkdata_capable and (self.has_args or self.has_kwargs):
649
+ forkdatacontext._reset()
650
+
651
+
652
+ class _ForkDataContext(threading.local):
653
+ def __init__(
654
+ self,
655
+ args=None,
656
+ kwargs=None,
657
+ ):
658
+ """
659
+ Global context for fork to carry data to subprocess instead of relying upon copy/pickle/serialization
660
+
661
+ :param args: args
662
+ :param kwargs: kwargs
663
+ """
664
+ assert isinstance(args, (tuple, type(None)))
665
+ assert isinstance(kwargs, (dict, type(None)))
666
+ self.__args = args
667
+ self.__kwargs = kwargs
668
+
669
+ @property
670
+ def args(self) -> Tuple:
671
+ """returns args"""
672
+ return self.__args
673
+
674
+ @args.setter
675
+ def args(self, args):
676
+ if self.__args is not None:
677
+ raise AttributeError(
678
+ "args cannot be overwritten: %s %s" % (str(self.__args), str(self.__kwargs))
679
+ )
680
+
681
+ self.__args = args
682
+
683
+ @property
684
+ def kwargs(self) -> Dict:
685
+ """returns kwargs"""
686
+ return self.__kwargs
687
+
688
+ @kwargs.setter
689
+ def kwargs(self, kwargs):
690
+ if self.__kwargs is not None:
691
+ raise AttributeError(
692
+ "kwargs cannot be overwritten: %s %s" % (str(self.__args), str(self.__kwargs))
693
+ )
694
+
695
+ self.__kwargs = kwargs
696
+
697
+ def _reset(self):
698
+ """Reset fork arg-kwarg context to default values"""
699
+ self.__args = None
700
+ self.__kwargs = None
701
+
702
+ def get_args_kwargs(self, func, args, kwargs) -> Tuple[Callable, Tuple, Dict]:
703
+ if self.__args:
704
+ args = self.__args[1:]
705
+ if not func:
706
+ assert len(self.__args) > 0, "if have no func, must have in args"
707
+ func = self.__args[0] # should always be there
708
+ if self.__kwargs:
709
+ kwargs = self.__kwargs
710
+ try:
711
+ return func, args, kwargs
712
+ finally:
713
+ forkdatacontext._reset()
714
+
715
+ @staticmethod
716
+ def get_args_kwargs_for_traced_func(func, args, kwargs):
717
+ """
718
+ Return args/kwargs out of forkdatacontext when using copy-on-write way of passing args/kwargs
719
+ :param func: actual function ran by _traced_func, which itself is directly what mppool treats as function
720
+ :param args:
721
+ :param kwargs:
722
+ :return: func, args, kwargs from forkdatacontext if used, else originals
723
+ """
724
+ # first 3 lines are debug
725
+ func_was_None = func is None
726
+ args_was_None_or_empty = args is None or len(args) == 0
727
+ kwargs_was_None_or_empty = kwargs is None or len(kwargs) == 0
728
+
729
+ forkdatacontext_args_was_None = forkdatacontext.args is None
730
+ forkdatacontext_kwargs_was_None = forkdatacontext.kwargs is None
731
+ func, args, kwargs = forkdatacontext.get_args_kwargs(func, args, kwargs)
732
+ using_forkdatacontext = func_was_None and func is not None # pulled func out of forkdatacontext.__args[0]
733
+ assert forkdatacontext.args is None, "forkdatacontext.args should be None after get_args_kwargs"
734
+ assert forkdatacontext.kwargs is None, "forkdatacontext.kwargs should be None after get_args_kwargs"
735
+
736
+ proc_type = kwargs.get('proc_type', 'SUBPROCESS')
737
+ if using_forkdatacontext:
738
+ assert proc_type == "SUBPROCESS"
739
+ if proc_type == "NORMAL":
740
+ assert forkdatacontext_args_was_None, "if no fork, expect forkdatacontext.args None entering _traced_func"
741
+ assert forkdatacontext_kwargs_was_None, "if no fork, expect forkdatacontext.kwargs None entering _traced_func"
742
+ assert func is not None, "function should not be None, indicates original args[0] was None or args was None"
743
+
744
+ return func, args, kwargs
745
+
746
+
747
+ forkdatacontext = _ForkDataContext()
748
+
749
+
750
+ def _traced_func(func, *args, **kwargs):
751
+ func, args, kwargs = forkdatacontext.get_args_kwargs_for_traced_func(func, args, kwargs)
752
+ return func(*args, **kwargs)
753
+
754
+
755
+ def call_subprocess_onetask(func, args=None, kwargs=None):
756
+ if isinstance(args, list):
757
+ args = tuple(args)
758
+ if args is None:
759
+ args = ()
760
+ if kwargs is None:
761
+ kwargs = {}
762
+ args = list(args)
763
+ args = [func] + args
764
+ args = tuple(args)
765
+ with ForkContext(args=args, kwargs=kwargs):
766
+ args = (None,)
767
+ kwargs = {}
768
+ with ProcessPoolExecutor(max_workers=1) as executor:
769
+ future = executor.submit(_traced_func, *args, **kwargs)
770
+ return future.result()
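A usage sketch for call_subprocess_onetask; heavy_work is a hypothetical function, and this path relies on the fork start method (the Linux default) so forkdatacontext is inherited copy-on-write rather than pickled.

# Hypothetical one-shot subprocess call; args/kwargs travel via forkdatacontext at fork time
# instead of being pickled along with the submitted task.
from utils import call_subprocess_onetask

def heavy_work(path, repeat=1):
    with open(path) as f:
        return len(f.read()) * repeat

n = call_subprocess_onetask(heavy_work, args=("LICENSE",), kwargs=dict(repeat=1))
print(n)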
771
+
772
+
773
+ class ProgressParallel(Parallel):
774
+ def __init__(self, use_tqdm=True, total=None, *args, **kwargs):
775
+ self._use_tqdm = use_tqdm
776
+ self._total = total
777
+ super().__init__(*args, **kwargs)
778
+
779
+ def __call__(self, *args, **kwargs):
780
+ with tqdm(disable=not self._use_tqdm, total=self._total) as self._pbar:
781
+ return Parallel.__call__(self, *args, **kwargs)
782
+
783
+ def print_progress(self):
784
+ if self._total is None:
785
+ self._pbar.total = self.n_dispatched_tasks
786
+ self._pbar.n = self.n_completed_tasks
787
+ self._pbar.refresh()
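A usage sketch for ProgressParallel; the squaring job is illustrative. It behaves like joblib.Parallel, with print_progress() driving a tqdm bar.

# Illustrative job: ProgressParallel drop-in for joblib.Parallel with a progress bar.
from joblib import delayed
from utils import ProgressParallel

def square(i):
    return i * i

results = ProgressParallel(use_tqdm=True, total=100, n_jobs=2)(
    delayed(square)(i) for i in range(100)
)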
788
+
789
+
790
+ def get_kwargs(func, exclude_names=None, **kwargs):
791
+ func_names = list(inspect.signature(func).parameters)
792
+ missing_kwargs = [x for x in func_names if x not in kwargs]
793
+ if exclude_names:
794
+ for k in exclude_names:
795
+ if k in missing_kwargs:
796
+ missing_kwargs.remove(k)
797
+ if k in func_names:
798
+ func_names.remove(k)
799
+ assert not missing_kwargs, "Missing %s" % missing_kwargs
800
+ kwargs = {k: v for k, v in kwargs.items() if k in func_names}
801
+ return kwargs
802
+
803
+
804
+ import pkg_resources
805
+ have_faiss = False
806
+
807
+ try:
808
+ assert pkg_resources.get_distribution('faiss') is not None
809
+ have_faiss = True
810
+ except (pkg_resources.DistributionNotFound, AssertionError):
811
+ pass
812
+ try:
813
+ assert pkg_resources.get_distribution('faiss_gpu') is not None
814
+ have_faiss = True
815
+ except (pkg_resources.DistributionNotFound, AssertionError):
816
+ pass
817
+ try:
818
+ assert pkg_resources.get_distribution('faiss_cpu') is not None
819
+ have_faiss = True
820
+ except (pkg_resources.DistributionNotFound, AssertionError):
821
+ pass
822
+
823
+
824
+ def hash_file(file):
825
+ try:
826
+ import hashlib
827
+
828
+ # BUF_SIZE is totally arbitrary, change for your app!
829
+ BUF_SIZE = 65536 # let's read stuff in 64KB chunks!
830
+
831
+ md5 = hashlib.md5()
832
+ #sha1 = hashlib.sha1()
833
+
834
+ with open(file, 'rb') as f:
835
+ while True:
836
+ data = f.read(BUF_SIZE)
837
+ if not data:
838
+ break
839
+ md5.update(data)
840
+ #sha1.update(data)
841
+ except BaseException as e:
842
+ print("Cannot hash %s due to %s" % (file, str(e)))
843
+ traceback.print_exc()
844
+ md5 = None
845
+ return md5.hexdigest() if md5 is not None else None
846
+
847
+
848
+ def start_faulthandler():
849
+ # If the server or any subprocess is hit with signal SIGUSR1, it prints all threads' stack traces, but won't quit or core dump
850
+ # If more than one fork tries to write at the same time, the output can look corrupted.
851
+ import faulthandler
852
+ import signal
853
+
854
+ # SIGUSR1 in h2oai/__init__.py as well
855
+ faulthandler.enable()
856
+ faulthandler.register(signal.SIGUSR1)
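A quick, illustrative way to exercise the handler registered above (SIGUSR1 is POSIX-only, so Linux/macOS):

# Send SIGUSR1 to the current process to trigger the registered faulthandler.
import os
import signal
from utils import start_faulthandler

start_faulthandler()
os.kill(os.getpid(), signal.SIGUSR1)  # dumps every thread's stack to stderr; the process keeps running
# for a running server, the shell equivalent is: kill -USR1 <pid>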