pseudotensor commited on
Commit
3a1a60a
1 Parent(s): 0f993c6

Update with h2oGPT hash 13a8343d2a96885985bda8c4480bbb23cf55bb9b

Browse files
Files changed (33) hide show
  1. LICENSE +0 -201
  2. LICENSE +1 -0
  3. client_test.py +0 -179
  4. client_test.py +1 -0
  5. create_data.py +0 -1809
  6. create_data.py +1 -0
  7. enums.py +1 -0
  8. finetune.py +0 -670
  9. finetune.py +1 -0
  10. generate.py +0 -1548
  11. generate.py +1 -0
  12. gpt4all_llm.py +0 -255
  13. gpt4all_llm.py +1 -0
  14. gpt_langchain.py +0 -1471
  15. gpt_langchain.py +1 -0
  16. gradio_runner.py +0 -1741
  17. gradio_runner.py +1 -0
  18. gradio_themes.py +0 -183
  19. gradio_themes.py +1 -0
  20. gradio_ui +1 -0
  21. h2o-logo.svg +0 -1
  22. h2o-logo.svg +1 -0
  23. h2oai_pipeline.py +0 -128
  24. h2oai_pipeline.py +1 -0
  25. loaders.py +0 -50
  26. loaders.py +1 -0
  27. prompter.py +0 -576
  28. prompter.py +1 -0
  29. requirements.txt +0 -100
  30. stopping.py +0 -72
  31. stopping.py +1 -0
  32. utils.py +0 -843
  33. utils.py +1 -0
LICENSE DELETED
@@ -1,201 +0,0 @@
1
- Apache License
2
- Version 2.0, January 2004
3
- http://www.apache.org/licenses/
4
-
5
- TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6
-
7
- 1. Definitions.
8
-
9
- "License" shall mean the terms and conditions for use, reproduction,
10
- and distribution as defined by Sections 1 through 9 of this document.
11
-
12
- "Licensor" shall mean the copyright owner or entity authorized by
13
- the copyright owner that is granting the License.
14
-
15
- "Legal Entity" shall mean the union of the acting entity and all
16
- other entities that control, are controlled by, or are under common
17
- control with that entity. For the purposes of this definition,
18
- "control" means (i) the power, direct or indirect, to cause the
19
- direction or management of such entity, whether by contract or
20
- otherwise, or (ii) ownership of fifty percent (50%) or more of the
21
- outstanding shares, or (iii) beneficial ownership of such entity.
22
-
23
- "You" (or "Your") shall mean an individual or Legal Entity
24
- exercising permissions granted by this License.
25
-
26
- "Source" form shall mean the preferred form for making modifications,
27
- including but not limited to software source code, documentation
28
- source, and configuration files.
29
-
30
- "Object" form shall mean any form resulting from mechanical
31
- transformation or translation of a Source form, including but
32
- not limited to compiled object code, generated documentation,
33
- and conversions to other media types.
34
-
35
- "Work" shall mean the work of authorship, whether in Source or
36
- Object form, made available under the License, as indicated by a
37
- copyright notice that is included in or attached to the work
38
- (an example is provided in the Appendix below).
39
-
40
- "Derivative Works" shall mean any work, whether in Source or Object
41
- form, that is based on (or derived from) the Work and for which the
42
- editorial revisions, annotations, elaborations, or other modifications
43
- represent, as a whole, an original work of authorship. For the purposes
44
- of this License, Derivative Works shall not include works that remain
45
- separable from, or merely link (or bind by name) to the interfaces of,
46
- the Work and Derivative Works thereof.
47
-
48
- "Contribution" shall mean any work of authorship, including
49
- the original version of the Work and any modifications or additions
50
- to that Work or Derivative Works thereof, that is intentionally
51
- submitted to Licensor for inclusion in the Work by the copyright owner
52
- or by an individual or Legal Entity authorized to submit on behalf of
53
- the copyright owner. For the purposes of this definition, "submitted"
54
- means any form of electronic, verbal, or written communication sent
55
- to the Licensor or its representatives, including but not limited to
56
- communication on electronic mailing lists, source code control systems,
57
- and issue tracking systems that are managed by, or on behalf of, the
58
- Licensor for the purpose of discussing and improving the Work, but
59
- excluding communication that is conspicuously marked or otherwise
60
- designated in writing by the copyright owner as "Not a Contribution."
61
-
62
- "Contributor" shall mean Licensor and any individual or Legal Entity
63
- on behalf of whom a Contribution has been received by Licensor and
64
- subsequently incorporated within the Work.
65
-
66
- 2. Grant of Copyright License. Subject to the terms and conditions of
67
- this License, each Contributor hereby grants to You a perpetual,
68
- worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69
- copyright license to reproduce, prepare Derivative Works of,
70
- publicly display, publicly perform, sublicense, and distribute the
71
- Work and such Derivative Works in Source or Object form.
72
-
73
- 3. Grant of Patent License. Subject to the terms and conditions of
74
- this License, each Contributor hereby grants to You a perpetual,
75
- worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76
- (except as stated in this section) patent license to make, have made,
77
- use, offer to sell, sell, import, and otherwise transfer the Work,
78
- where such license applies only to those patent claims licensable
79
- by such Contributor that are necessarily infringed by their
80
- Contribution(s) alone or by combination of their Contribution(s)
81
- with the Work to which such Contribution(s) was submitted. If You
82
- institute patent litigation against any entity (including a
83
- cross-claim or counterclaim in a lawsuit) alleging that the Work
84
- or a Contribution incorporated within the Work constitutes direct
85
- or contributory patent infringement, then any patent licenses
86
- granted to You under this License for that Work shall terminate
87
- as of the date such litigation is filed.
88
-
89
- 4. Redistribution. You may reproduce and distribute copies of the
90
- Work or Derivative Works thereof in any medium, with or without
91
- modifications, and in Source or Object form, provided that You
92
- meet the following conditions:
93
-
94
- (a) You must give any other recipients of the Work or
95
- Derivative Works a copy of this License; and
96
-
97
- (b) You must cause any modified files to carry prominent notices
98
- stating that You changed the files; and
99
-
100
- (c) You must retain, in the Source form of any Derivative Works
101
- that You distribute, all copyright, patent, trademark, and
102
- attribution notices from the Source form of the Work,
103
- excluding those notices that do not pertain to any part of
104
- the Derivative Works; and
105
-
106
- (d) If the Work includes a "NOTICE" text file as part of its
107
- distribution, then any Derivative Works that You distribute must
108
- include a readable copy of the attribution notices contained
109
- within such NOTICE file, excluding those notices that do not
110
- pertain to any part of the Derivative Works, in at least one
111
- of the following places: within a NOTICE text file distributed
112
- as part of the Derivative Works; within the Source form or
113
- documentation, if provided along with the Derivative Works; or,
114
- within a display generated by the Derivative Works, if and
115
- wherever such third-party notices normally appear. The contents
116
- of the NOTICE file are for informational purposes only and
117
- do not modify the License. You may add Your own attribution
118
- notices within Derivative Works that You distribute, alongside
119
- or as an addendum to the NOTICE text from the Work, provided
120
- that such additional attribution notices cannot be construed
121
- as modifying the License.
122
-
123
- You may add Your own copyright statement to Your modifications and
124
- may provide additional or different license terms and conditions
125
- for use, reproduction, or distribution of Your modifications, or
126
- for any such Derivative Works as a whole, provided Your use,
127
- reproduction, and distribution of the Work otherwise complies with
128
- the conditions stated in this License.
129
-
130
- 5. Submission of Contributions. Unless You explicitly state otherwise,
131
- any Contribution intentionally submitted for inclusion in the Work
132
- by You to the Licensor shall be under the terms and conditions of
133
- this License, without any additional terms or conditions.
134
- Notwithstanding the above, nothing herein shall supersede or modify
135
- the terms of any separate license agreement you may have executed
136
- with Licensor regarding such Contributions.
137
-
138
- 6. Trademarks. This License does not grant permission to use the trade
139
- names, trademarks, service marks, or product names of the Licensor,
140
- except as required for reasonable and customary use in describing the
141
- origin of the Work and reproducing the content of the NOTICE file.
142
-
143
- 7. Disclaimer of Warranty. Unless required by applicable law or
144
- agreed to in writing, Licensor provides the Work (and each
145
- Contributor provides its Contributions) on an "AS IS" BASIS,
146
- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147
- implied, including, without limitation, any warranties or conditions
148
- of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149
- PARTICULAR PURPOSE. You are solely responsible for determining the
150
- appropriateness of using or redistributing the Work and assume any
151
- risks associated with Your exercise of permissions under this License.
152
-
153
- 8. Limitation of Liability. In no event and under no legal theory,
154
- whether in tort (including negligence), contract, or otherwise,
155
- unless required by applicable law (such as deliberate and grossly
156
- negligent acts) or agreed to in writing, shall any Contributor be
157
- liable to You for damages, including any direct, indirect, special,
158
- incidental, or consequential damages of any character arising as a
159
- result of this License or out of the use or inability to use the
160
- Work (including but not limited to damages for loss of goodwill,
161
- work stoppage, computer failure or malfunction, or any and all
162
- other commercial damages or losses), even if such Contributor
163
- has been advised of the possibility of such damages.
164
-
165
- 9. Accepting Warranty or Additional Liability. While redistributing
166
- the Work or Derivative Works thereof, You may choose to offer,
167
- and charge a fee for, acceptance of support, warranty, indemnity,
168
- or other liability obligations and/or rights consistent with this
169
- License. However, in accepting such obligations, You may act only
170
- on Your own behalf and on Your sole responsibility, not on behalf
171
- of any other Contributor, and only if You agree to indemnify,
172
- defend, and hold each Contributor harmless for any liability
173
- incurred by, or claims asserted against, such Contributor by reason
174
- of your accepting any such warranty or additional liability.
175
-
176
- END OF TERMS AND CONDITIONS
177
-
178
- APPENDIX: How to apply the Apache License to your work.
179
-
180
- To apply the Apache License to your work, attach the following
181
- boilerplate notice, with the fields enclosed by brackets "[]"
182
- replaced with your own identifying information. (Don't include
183
- the brackets!) The text should be enclosed in the appropriate
184
- comment syntax for the file format. We also recommend that a
185
- file or class name and description of purpose be included on the
186
- same "printed page" as the copyright notice for easier
187
- identification within third-party archives.
188
-
189
- Copyright 2023 Rohan Taori, Ishaan Gulrajani, Tianyi Zhang, Yann Dubois, Xuechen Li
190
-
191
- Licensed under the Apache License, Version 2.0 (the "License");
192
- you may not use this file except in compliance with the License.
193
- You may obtain a copy of the License at
194
-
195
- http://www.apache.org/licenses/LICENSE-2.0
196
-
197
- Unless required by applicable law or agreed to in writing, software
198
- distributed under the License is distributed on an "AS IS" BASIS,
199
- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200
- See the License for the specific language governing permissions and
201
- limitations under the License.
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
LICENSE ADDED
@@ -0,0 +1 @@
 
 
1
+ ../../LICENSE
client_test.py DELETED
@@ -1,179 +0,0 @@
1
- """
2
- Client test.
3
-
4
- Run server:
5
-
6
- python generate.py --base_model=h2oai/h2ogpt-oig-oasst1-512-6_9b
7
-
8
- NOTE: For private models, add --use-auth_token=True
9
-
10
- NOTE: --infer_devices=True (default) must be used for multi-GPU in case see failures with cuda:x cuda:y mismatches.
11
- Currently, this will force model to be on a single GPU.
12
-
13
- Then run this client as:
14
-
15
- python client_test.py
16
-
17
-
18
-
19
- For HF spaces:
20
-
21
- HOST="https://h2oai-h2ogpt-chatbot.hf.space" python client_test.py
22
-
23
- Result:
24
-
25
- Loaded as API: https://h2oai-h2ogpt-chatbot.hf.space ✔
26
- {'instruction_nochat': 'Who are you?', 'iinput_nochat': '', 'response': 'I am h2oGPT, a large language model developed by LAION.', 'sources': ''}
27
-
28
-
29
- For demo:
30
-
31
- HOST="https://gpt.h2o.ai" python client_test.py
32
-
33
- Result:
34
-
35
- Loaded as API: https://gpt.h2o.ai ✔
36
- {'instruction_nochat': 'Who are you?', 'iinput_nochat': '', 'response': 'I am h2oGPT, a chatbot created by LAION.', 'sources': ''}
37
-
38
- NOTE: Raw output from API for nochat case is a string of a python dict and will remain so if other entries are added to dict:
39
-
40
- {'response': "I'm h2oGPT, a large language model by H2O.ai, the visionary leader in democratizing AI.", 'sources': ''}
41
-
42
-
43
- """
44
- import ast
45
- import time
46
- import os
47
- import markdown # pip install markdown
48
- import pytest
49
- from bs4 import BeautifulSoup # pip install beautifulsoup4
50
-
51
- debug = False
52
-
53
- os.environ['HF_HUB_DISABLE_TELEMETRY'] = '1'
54
-
55
-
56
- def get_client(serialize=True):
57
- from gradio_client import Client
58
-
59
- client = Client(os.getenv('HOST', "http://localhost:7860"), serialize=serialize)
60
- if debug:
61
- print(client.view_api(all_endpoints=True))
62
- return client
63
-
64
-
65
- def get_args(prompt, prompt_type, chat=False, stream_output=False, max_new_tokens=50, langchain_mode='Disabled'):
66
- from collections import OrderedDict
67
- kwargs = OrderedDict(instruction=prompt if chat else '', # only for chat=True
68
- iinput='', # only for chat=True
69
- context='',
70
- # streaming output is supported, loops over and outputs each generation in streaming mode
71
- # but leave stream_output=False for simple input/output mode
72
- stream_output=stream_output,
73
- prompt_type=prompt_type,
74
- temperature=0.1,
75
- top_p=0.75,
76
- top_k=40,
77
- num_beams=1,
78
- max_new_tokens=max_new_tokens,
79
- min_new_tokens=0,
80
- early_stopping=False,
81
- max_time=20,
82
- repetition_penalty=1.0,
83
- num_return_sequences=1,
84
- do_sample=True,
85
- chat=chat,
86
- instruction_nochat=prompt if not chat else '',
87
- iinput_nochat='', # only for chat=False
88
- langchain_mode=langchain_mode,
89
- top_k_docs=4,
90
- document_choice=['All'],
91
- )
92
- if chat:
93
- # add chatbot output on end. Assumes serialize=False
94
- kwargs.update(dict(chatbot=[]))
95
-
96
- return kwargs, list(kwargs.values())
97
-
98
-
99
- @pytest.mark.skip(reason="For manual use against some server, no server launched")
100
- def test_client_basic():
101
- return run_client_nochat(prompt='Who are you?', prompt_type='human_bot', max_new_tokens=50)
102
-
103
-
104
- def run_client_nochat(prompt, prompt_type, max_new_tokens):
105
- kwargs, args = get_args(prompt, prompt_type, chat=False, max_new_tokens=max_new_tokens)
106
-
107
- api_name = '/submit_nochat'
108
- client = get_client(serialize=True)
109
- res = client.predict(
110
- *tuple(args),
111
- api_name=api_name,
112
- )
113
- print("Raw client result: %s" % res, flush=True)
114
- res_dict = dict(prompt=kwargs['instruction_nochat'], iinput=kwargs['iinput_nochat'],
115
- response=md_to_text(ast.literal_eval(res)['response']),
116
- sources=ast.literal_eval(res)['sources'])
117
- print(res_dict)
118
- return res_dict
119
-
120
-
121
- @pytest.mark.skip(reason="For manual use against some server, no server launched")
122
- def test_client_chat():
123
- return run_client_chat(prompt='Who are you?', prompt_type='human_bot', stream_output=False, max_new_tokens=50,
124
- langchain_mode='Disabled')
125
-
126
-
127
- def run_client_chat(prompt, prompt_type, stream_output, max_new_tokens, langchain_mode):
128
- client = get_client(serialize=False)
129
-
130
- kwargs, args = get_args(prompt, prompt_type, chat=True, stream_output=stream_output,
131
- max_new_tokens=max_new_tokens, langchain_mode=langchain_mode)
132
- return run_client(client, prompt, args, kwargs)
133
-
134
-
135
- def run_client(client, prompt, args, kwargs, do_md_to_text=True, verbose=False):
136
- res = client.predict(*tuple(args), api_name='/instruction')
137
- args[-1] += [res[-1]]
138
-
139
- res_dict = kwargs
140
- res_dict['prompt'] = prompt
141
- if not kwargs['stream_output']:
142
- res = client.predict(*tuple(args), api_name='/instruction_bot')
143
- res_dict['response'] = res[0][-1][1]
144
- print(md_to_text(res_dict['response'], do_md_to_text=do_md_to_text))
145
- return res_dict, client
146
- else:
147
- job = client.submit(*tuple(args), api_name='/instruction_bot')
148
- res1 = ''
149
- while not job.done():
150
- outputs_list = job.communicator.job.outputs
151
- if outputs_list:
152
- res = job.communicator.job.outputs[-1]
153
- res1 = res[0][-1][-1]
154
- res1 = md_to_text(res1, do_md_to_text=do_md_to_text)
155
- print(res1)
156
- time.sleep(0.1)
157
- full_outputs = job.outputs()
158
- if verbose:
159
- print('job.outputs: %s' % str(full_outputs))
160
- # ensure get ending to avoid race
161
- # -1 means last response if streaming
162
- # 0 means get text_output, ignore exception_text
163
- # 0 means get list within text_output that looks like [[prompt], [answer]]
164
- # 1 means get bot answer, so will have last bot answer
165
- res_dict['response'] = md_to_text(full_outputs[-1][0][0][1], do_md_to_text=do_md_to_text)
166
- return res_dict, client
167
-
168
-
169
- def md_to_text(md, do_md_to_text=True):
170
- if not do_md_to_text:
171
- return md
172
- assert md is not None, "Markdown is None"
173
- html = markdown.markdown(md)
174
- soup = BeautifulSoup(html, features='html.parser')
175
- return soup.get_text()
176
-
177
-
178
- if __name__ == '__main__':
179
- test_client_basic()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
client_test.py ADDED
@@ -0,0 +1 @@
 
 
1
+ ../../client_test.py
create_data.py DELETED
@@ -1,1809 +0,0 @@
1
- """
2
- Dataset creation tools.
3
-
4
- Keep to-level imports clean of non-trivial imports for specific tools,
5
- because this file is imported for various purposes
6
- """
7
-
8
- import ast
9
- import concurrent.futures
10
- import contextlib
11
- import hashlib
12
- import json
13
- import os
14
- import shutil
15
- import signal
16
- import sys
17
- import traceback
18
- from concurrent.futures import ProcessPoolExecutor
19
-
20
- import psutil
21
- import pytest
22
- import pandas as pd
23
- import numpy as np
24
- from tqdm import tqdm
25
-
26
- from utils import flatten_list, remove
27
-
28
-
29
- def parse_rst_file(filepath):
30
- with open(filepath, 'r') as f:
31
- input_data = f.read()
32
- settings_overrides = {'initial_header_level': 2}
33
- from docutils import core
34
- document = core.publish_doctree(
35
- source=input_data,
36
- source_path=filepath,
37
- settings_overrides=settings_overrides,
38
- )
39
- qa_pairs = []
40
- current_section = None
41
- current_question = ""
42
- current_answer = ""
43
- for node in document.traverse():
44
- if node.__class__.__name__ == 'section':
45
- current_section = ""
46
- elif current_section is not None:
47
- if node.__class__.__name__ == 'Text':
48
- if node.astext()[-1] == "?":
49
- if current_question:
50
- qa_pairs.append((current_question, current_answer))
51
- current_question = node.astext()
52
- current_answer = ""
53
- else:
54
- current_answer += node.astext()
55
- if current_answer:
56
- qa_pairs.append((current_question, current_answer))
57
- return {k: v for k, v in qa_pairs}
58
-
59
-
60
- def test_scrape_dai_docs():
61
- home = os.path.expanduser('~')
62
- file = os.path.join(home, 'h2oai/docs/faq.rst')
63
- qa_pairs = parse_rst_file(file)
64
- prompt_type = 'human_bot'
65
- from prompter import prompt_types
66
- assert prompt_type in prompt_types
67
- save_thing = [{"instruction": k, "output": v, 'prompt_type': prompt_type} for k, v in qa_pairs.items()]
68
- output_file = "dai_faq.json"
69
- with open(output_file, "wt") as f:
70
- f.write(json.dumps(save_thing, indent=2))
71
-
72
-
73
- def test_scrape_dai_docs_all():
74
- """
75
- pytest create_data.py::test_scrape_dai_docs_all
76
- """
77
- import glob
78
- import nltk
79
- nltk.download('punkt')
80
- dd = {}
81
- np.random.seed(1234)
82
- home = os.path.expanduser('~')
83
- files = list(glob.glob(os.path.join(home, "h2oai/docs/**/*rst")))
84
- np.random.shuffle(files)
85
- val_count = int(0.05 * len(files))
86
- train_files = files[val_count:]
87
- valid_files = files[:val_count]
88
- things = [
89
- ("dai_docs.train.json", train_files),
90
- ("dai_docs.valid.json", valid_files)
91
- ]
92
- for LEN in [100, 200, 500]:
93
- for output_file, ff in things:
94
- if output_file not in dd:
95
- dd[output_file] = []
96
- for f in ff:
97
- with open(f) as input:
98
- blob = input.read()
99
- blob = blob.replace("~~", "")
100
- blob = blob.replace("==", "")
101
- blob = blob.replace("''", "")
102
- blob = blob.replace("--", "")
103
- blob = blob.replace("**", "")
104
- dd[output_file].extend(get_sentences(blob, length=LEN))
105
- for output_file, _ in things:
106
- save_thing = [{"output": k.strip(), 'prompt_type': 'plain'} for k in dd[output_file]]
107
- with open(output_file, "wt") as f:
108
- f.write(json.dumps(save_thing, indent=2))
109
-
110
-
111
- def get_sentences(blob, length):
112
- """
113
- break-up input text into sentences and then output list of sentences of about length in size
114
- :param blob:
115
- :param length:
116
- :return:
117
- """
118
- import nltk
119
- nltk.download('punkt')
120
- from nltk.tokenize import sent_tokenize
121
- sentences = sent_tokenize(blob)
122
- my_sentences = []
123
- my_string = ""
124
- for sentence in sentences:
125
- if len(my_string) + len(sentence) <= length:
126
- if my_string:
127
- my_string += " " + sentence
128
- else:
129
- my_string = sentence
130
- else:
131
- my_sentences.append(my_string)
132
- my_string = ""
133
- return my_sentences or [my_string]
134
-
135
-
136
- def setup_dai_docs(path=None, dst="working_dir_docs", from_hf=False):
137
- """
138
- Only supported if have access to source code or HF token for HF spaces and from_hf=True
139
- :param path:
140
- :param dst:
141
- :param from_hf:
142
- :return:
143
- """
144
-
145
- home = os.path.expanduser('~')
146
-
147
- if from_hf:
148
- # assumes
149
- from huggingface_hub import hf_hub_download
150
- # True for case when locally already logged in with correct token, so don't have to set key
151
- token = os.getenv('HUGGINGFACE_API_TOKEN', True)
152
- path_to_zip_file = hf_hub_download('h2oai/dai_docs', 'dai_docs.zip', token=token, repo_type='dataset')
153
- path = 'h2oai'
154
- import zipfile
155
- with zipfile.ZipFile(path_to_zip_file, 'r') as zip_ref:
156
- zip_ref.extractall(path)
157
- path = os.path.join(path, 'docs/**/*')
158
-
159
- if path is None:
160
- if os.path.isdir(os.path.join(home, 'h2oai')):
161
- path = os.path.join(home, "h2oai/docs/**/*")
162
- else:
163
- assert os.path.isdir(os.path.join(home, 'h2oai.superclean')), '%s does not exist' % path
164
- path = os.path.join(home, "h2oai.superclean/docs/**/*")
165
- import glob
166
- files = list(glob.glob(path, recursive=True))
167
-
168
- # pandoc can't find include files
169
-
170
- remove(dst)
171
- os.makedirs(dst)
172
-
173
- # copy full tree, for absolute paths in rst
174
- for fil in files:
175
- if os.path.isfile(fil):
176
- shutil.copy(fil, dst)
177
-
178
- # hack for relative path
179
- scorers_dir = os.path.join(dst, 'scorers')
180
- makedirs(scorers_dir)
181
- for fil in glob.glob(os.path.join(dst, '*.frag')):
182
- shutil.copy(fil, scorers_dir)
183
-
184
- return dst
185
-
186
-
187
- def rst_to_outputs(files, min_len=30, max_len=2048 // 2 - 30):
188
- # account for sequence length (context window) including prompt and input and output
189
-
190
- # os.system('pandoc -f rst -t plain ./expert_settings/nlp_settings.rst')
191
- import pypandoc
192
- basedir = os.path.abspath(os.getcwd())
193
-
194
- outputs = []
195
- for fil in files:
196
- os.chdir(basedir)
197
- os.chdir(os.path.dirname(fil))
198
- fil = os.path.basename(fil)
199
- print("Processing %s" % fil, flush=True)
200
- # out_format can be one of: asciidoc, asciidoctor, beamer, biblatex, bibtex, commonmark, commonmark_x,
201
- # context, csljson, docbook, docbook4, docbook5, docx, dokuwiki,
202
- # dzslides, epub, epub2, epub3, fb2, gfm, haddock, html, html4, html5, icml,
203
- # ipynb, jats, jats_archiving, jats_articleauthoring, jats_publishing, jira,
204
- # json, latex, man,
205
- # markdown, markdown_github, markdown_mmd, markdown_phpextra, markdown_strict,
206
- # mediawiki, ms, muse, native, odt, opendocument, opml, org, pdf, plain, pptx,
207
- # revealjs, rst, rtf, s5, slideous, slidy, tei, texinfo, textile, xwiki, zimwiki
208
- out_format = 'plain'
209
- # avoid extra new lines injected into text
210
- extra_args = ['--wrap=preserve', '--resource path="%s" % dst']
211
-
212
- plain_list = []
213
- try:
214
- # valid for expert settings
215
- input_rst = pypandoc.convert_file(fil, 'rst')
216
- input_list = input_rst.split('\n``')
217
- for input_subrst in input_list:
218
- input_plain = pypandoc.convert_text(input_subrst, format='rst', to='plain')
219
- plain_list.append([input_plain, fil])
220
- except Exception as e:
221
- print("file exception: %s %s" % (fil, str(e)), flush=True)
222
-
223
- if not plain_list:
224
- # if failed to process as pieces of rst, then
225
- output = pypandoc.convert_file(fil, out_format, extra_args=extra_args, format='rst')
226
- outputs1 = get_sentences(output, length=max_len)
227
- for oi, output in enumerate(outputs1):
228
- output = output.replace('\n\n', '\n')
229
- plain_list.append([output, fil])
230
- outputs.extend(plain_list)
231
-
232
- # report:
233
- # [print(len(x)) for x in outputs]
234
-
235
- # deal with blocks longer than context size (sequence length) of 2048
236
- new_outputs = []
237
- num_truncated = 0
238
- num_orig = len(outputs)
239
- for output, fil in outputs:
240
- if len(output) < max_len:
241
- new_outputs.append([output, fil])
242
- continue
243
- outputs1 = get_sentences(output, length=max_len)
244
- for oi, output1 in enumerate(outputs1):
245
- output1 = output1.replace('\n\n', '\n')
246
- new_outputs.append([output1, fil])
247
- num_truncated += 1
248
- print('num_orig: %s num_truncated: %s' % (num_orig, num_truncated), flush=True)
249
-
250
- new_outputs = [[k.strip(), fil] for k, fil in new_outputs if len(k.strip()) > min_len]
251
-
252
- return new_outputs
253
-
254
-
255
- def test_scrape_dai_docs_all_pandoc():
256
- """
257
- pytest -s -v create_data.py::test_scrape_dai_docs_all_pandoc
258
- :return:
259
- """
260
-
261
- dst = setup_dai_docs()
262
-
263
- import glob
264
- files = list(glob.glob(os.path.join(dst, '*rst'), recursive=True))
265
-
266
- basedir = os.path.abspath(os.getcwd())
267
- new_outputs = rst_to_outputs(files)
268
- os.chdir(basedir)
269
-
270
- remove(dst)
271
- save_thing = [{"output": k.strip(), 'prompt_type': 'plain'} for k in new_outputs]
272
- output_file = "dai_docs.train_cleaned.json"
273
- with open(output_file, "wt") as f:
274
- f.write(json.dumps(save_thing, indent=2))
275
-
276
-
277
- def test_config_to_json():
278
- """
279
- Needs to run from Driverless AI source directory.
280
- E.g. (base) jon@gpu:~/h2oai$ pytest -s -v /data/jon/h2ogpt/create_data.py::test_config_to_json ; cp config.json /data/jon/h2ogpt/
281
- :return:
282
- """
283
- try:
284
- # Arrange
285
- import json
286
- from h2oaicore.systemutils import config
287
- toml_list = []
288
- for k, v in config.get_meta_dict().items():
289
- title = (v.title + ": ") if v.title else ''
290
- comment = v.comment or ''
291
- if not (title or comment):
292
- continue
293
- toml_list.extend(
294
- [
295
- {
296
- 'prompt_type': 'plain',
297
- 'instruction': f"<human>: What does {k} do?\n<bot>: {k.replace('_', ' ')} config.toml: {comment or title}\n<human>:".replace(
298
- "\n", ""),
299
- },
300
- {
301
- 'prompt_type': 'plain',
302
- 'instruction': f"<human>: Explain {k}.\n<bot>: {k.replace('_', ' ')} config.toml: {comment or title}\n<human>:".replace(
303
- "\n", ""),
304
- },
305
- {
306
- 'prompt_type': 'plain',
307
- 'instruction': f"<human>: How can I do this: {title}.\n<bot>: Set the {k.replace('_', ' ')} config.toml\n<human>:".replace(
308
- "\n", ""),
309
- } if title and comment else None,
310
- {
311
- 'prompt_type': 'human_bot',
312
- 'instruction': f'Explain the following expert setting for Driverless AI',
313
- 'input': f"{k}",
314
- 'output': f"{k.replace('_', ' ')} config.toml: {comment or title}".replace("\n", ""),
315
- },
316
- {
317
- 'prompt_type': 'human_bot',
318
- 'instruction': f'Explain the following expert setting for Driverless AI',
319
- 'input': f"{k}",
320
- 'output': f"{k.replace('_', ' ')} config.toml: {title}{comment}".replace("\n", ""),
321
- },
322
- {
323
- 'prompt_type': 'human_bot',
324
- 'instruction': f'Explain the following expert setting for Driverless AI',
325
- 'input': f"{k.replace('_', ' ')}",
326
- 'output': f"{k.replace('_', ' ')} config.toml: {title}{comment}".replace("\n", ""),
327
- },
328
- {
329
- 'prompt_type': 'human_bot',
330
- 'instruction': f'Explain the following expert setting for Driverless AI',
331
- 'input': f"{title}",
332
- 'output': f"{k.replace('_', ' ')} config.toml: {title}{comment}".replace("\n", ""),
333
- },
334
- {
335
- 'prompt_type': 'human_bot',
336
- 'instruction': f'Provide a short explanation of the expert setting {k}',
337
- 'output': f"{k.replace('_', ' ')} config.toml: {comment or title}".replace("\n", ""),
338
- },
339
- {
340
- 'prompt_type': 'human_bot',
341
- 'instruction': f'Provide a detailed explanation of the expert setting {k}',
342
- 'output': f"{k.replace('_', ' ')} config.toml: {title}{comment}".replace("\n", ""),
343
- },
344
- ]
345
- )
346
- toml_list = [x for x in toml_list if x]
347
- with open("config.json", "wt") as f:
348
- f.write(json.dumps(toml_list, indent=2))
349
- except Exception as e:
350
- print("Exception: %s" % str(e), flush=True)
351
-
352
-
353
- def copy_tree(src, dst, follow_symlink=False):
354
- makedirs(dst, exist_ok=True)
355
- for (path, dirs, files) in os.walk(src, followlinks=follow_symlink):
356
- new_path = path.replace(src, dst)
357
- makedirs(new_path, exist_ok=True)
358
- for file in files:
359
- filename = os.path.join(path, file)
360
- new_filename = os.path.join(new_path, file)
361
- # print("%s -> %s" % (filename, new_filename))
362
- try:
363
- atomic_copy(filename, new_filename)
364
- except FileNotFoundError:
365
- pass
366
-
367
-
368
- def atomic_move(src, dst):
369
- try:
370
- shutil.move(src, dst)
371
- except (shutil.Error, FileExistsError):
372
- pass
373
- remove(src)
374
-
375
-
376
- def atomic_copy(src=None, dst=None, with_permissions=True):
377
- if os.path.isfile(dst):
378
- return
379
- import uuid
380
- my_uuid = uuid.uuid4()
381
- dst_tmp = dst + str(my_uuid)
382
- makedirs(os.path.dirname(dst), exist_ok=True)
383
- if with_permissions:
384
- shutil.copy(src, dst_tmp)
385
- else:
386
- shutil.copyfile(src, dst_tmp)
387
- atomic_move(dst_tmp, dst)
388
- remove(dst_tmp)
389
-
390
-
391
- def makedirs(path, exist_ok=True):
392
- """
393
- Avoid some inefficiency in os.makedirs()
394
- :param path:
395
- :param exist_ok:
396
- :return:
397
- """
398
- if os.path.isdir(path) and os.path.exists(path):
399
- assert exist_ok, "Path already exists"
400
- return path
401
- os.makedirs(path, exist_ok=exist_ok)
402
-
403
-
404
- ## Download from https://huggingface.co/datasets/anon8231489123/ShareGPT_Vicuna_unfiltered/resolve/main/ShareGPT_unfiltered_cleaned_split.json
405
- ## Turn into simple instruct prompt type. No context/previous conversations.
406
- def test_prep_instruct_vicuna():
407
- from datasets import load_dataset
408
- filename = 'ShareGPT_unfiltered_cleaned_split.json'
409
- if not os.path.exists(filename):
410
- os.system(
411
- 'wget https://huggingface.co/datasets/anon8231489123/ShareGPT_Vicuna_unfiltered/resolve/main/%s' % filename)
412
- data = load_dataset("json", data_files={"train": filename})["train"]
413
- training_rows = []
414
- for i in range(data.num_rows):
415
- conversations = data[i]['conversations']
416
- assert isinstance(conversations, list), conversations
417
- convo = ""
418
- for j, conv in enumerate(conversations):
419
- # Get ready for generate.py prompt_type=human_bot
420
- # But train with prompt_type=plain
421
- if conv['from'] == 'human':
422
- FROM = '<human>: '
423
- elif conv['from'] == 'gpt':
424
- FROM = '<bot>: '
425
- convo += f"{FROM}" + conv['value'] + "\n"
426
- if convo:
427
- training_rows.append(dict(input=convo))
428
- with open(filename + ".generate_human_bot.train_plain.json", "wt") as f:
429
- f.write(json.dumps(training_rows, indent=2))
430
-
431
-
432
- POSTFIX = ".generate_human_bot.train_plain.json"
433
-
434
- # https://bair.berkeley.edu/blog/2023/04/03/koala/
435
- OIG_DATASETS = [
436
- "unified_chip2.jsonl",
437
- "unified_grade_school_math_instructions.jsonl",
438
- "unified_poetry_2_song.jsonl",
439
- "unified_plot_screenplay_books_dialog.jsonl",
440
- ]
441
-
442
- # hub issue: https://huggingface.co/datasets/laion/OIG/discussions/4
443
- ALL_OIG_DATASETS = ['unified_abstract_infill.jsonl',
444
- 'unified_basic.jsonl',
445
- 'unified_canadian_parliament.jsonl',
446
- 'unified_chip2.jsonl',
447
- 'unified_conv_finqa.jsonl',
448
- 'unified_cuad.jsonl',
449
- 'unified_essays.jsonl',
450
- 'unified_flan.jsonl.gz',
451
- 'unified_grade_school_math_instructions.jsonl',
452
- 'unified_hc3_human.jsonl',
453
- 'unified_image_prompts_instructions.jsonl',
454
- 'unified_joke_explanations.jsonl',
455
- 'unified_mathqa_flanv2_kojma_cot.jsonl',
456
- 'unified_merged_code_xp3.jsonl',
457
- 'unified_multi_news.jsonl',
458
- 'unified_multi_sum.jsonl',
459
- 'unified_ni.jsonl.gz',
460
- 'unified_nq.jsonl',
461
- 'unified_openai_summarize_tldr.jsonl',
462
- 'unified_oscar_en_sample_dialog.jsonl',
463
- 'unified_p3.jsonl.gz',
464
- 'unified_plot_screenplay_books_dialog.jsonl',
465
- 'unified_poetry_2_song.jsonl',
466
- 'unified_poetry_instructions.jsonl',
467
- 'unified_rallio_safety_and_prosocial.jsonl',
468
- 'unified_rallio_soda_upgraded_2048.jsonl',
469
- 'unified_soda_dialog.jsonl',
470
- 'unified_sqlv1.jsonl',
471
- 'unified_sqlv2.jsonl',
472
- 'unified_squad_v2.jsonl',
473
- 'unified_squad_v2_more_neg.jsonl',
474
- 'unified_ul2_plus_oscar_en_sample_dialog.jsonl',
475
- 'unified_unifiedskg_instructions.jsonl',
476
- 'unified_unnatural_instructions.jsonl',
477
- 'unified_xp3_sample.jsonl']
478
-
479
- useful_oig_files = ['unified_rallio_safety_and_prosocial.jsonl.parquet',
480
- 'unified_chip2.jsonl.parquet',
481
- 'unified_cuad.jsonl.parquet',
482
- 'unified_essays.jsonl.parquet',
483
- 'unified_flan.jsonl.gz.parquet',
484
- 'unified_grade_school_math_instructions.jsonl.parquet',
485
- 'unified_hc3_human.jsonl.parquet',
486
- 'unified_mathqa_flanv2_kojma_cot.jsonl.parquet',
487
- 'unified_merged_code_xp3.jsonl.parquet',
488
- 'unified_multi_news.jsonl.parquet',
489
- # 'unified_multi_sum.jsonl.parquet'
490
- 'unified_ni.jsonl.gz.parquet',
491
- 'unified_openai_summarize_tldr.jsonl.parquet',
492
- # 'unified_oscar_en_sample_dialog.jsonl.parquet', # create text containing these N words, not specific
493
- 'unified_plot_screenplay_books_dialog.jsonl.parquet',
494
- 'unified_soda_dialog.jsonl.parquet',
495
- 'unified_unnatural_instructions.jsonl.parquet',
496
- ]
497
-
498
-
499
- @pytest.mark.parametrize("filename", OIG_DATASETS)
500
- def test_get_small_sample_oig_data(filename):
501
- if not os.path.exists(filename):
502
- os.system('wget https://huggingface.co/datasets/laion/OIG/resolve/main/%s' % filename)
503
- import json
504
- rows = []
505
- with open(filename, "r") as f:
506
- for line in f.readlines():
507
- row = json.loads(line)
508
- rows.append(dict(input=row["text"]))
509
- with open(filename + POSTFIX, "w") as f:
510
- f.write(json.dumps(rows, indent=2))
511
-
512
-
513
- @pytest.mark.parametrize("filename", ALL_OIG_DATASETS)
514
- def test_download_useful_data_as_parquet(filename):
515
- dest_file = filename + '.parquet'
516
- if dest_file not in useful_oig_files:
517
- pytest.skip('file declared not useful')
518
- if not os.path.exists(filename):
519
- os.system('wget https://huggingface.co/datasets/laion/OIG/resolve/main/%s' % filename)
520
- if not os.path.exists(dest_file):
521
- df = pd.read_json(path_or_buf=filename, lines=True)
522
- df.to_parquet(dest_file, index=False)
523
-
524
-
525
- def test_merge_shuffle_small_sample_oig_data():
526
- np.random.seed(1234)
527
- rows = []
528
- for filename in OIG_DATASETS:
529
- with open(filename + POSTFIX, "r") as f:
530
- rows.extend(json.loads(f.read()))
531
- np.random.shuffle(rows)
532
- with open("merged_shuffled_OIG_%s.json" % hashlib.sha256(str(OIG_DATASETS).encode()).hexdigest()[:10], "w") as f:
533
- f.write(json.dumps(rows, indent=2))
534
-
535
-
536
- def test_join_jsons():
537
- files = ['config.json'] * 1 + \
538
- ['dai_docs.train_cleaned.json'] * 2 + \
539
- ['dai_faq.json'] * 3
540
- print(files)
541
- lst = []
542
- [lst.extend(json.load(open(fil, 'rt'))) for fil in files]
543
- print(len(lst))
544
- json.dump(lst, open("merged.json", "wt"), indent=2)
545
-
546
-
547
- @pytest.mark.parametrize("filename", ['Anthropic/hh-rlhf'])
548
- def test_make_rlhf_good_data(filename):
549
- from datasets import load_dataset
550
- rows = load_dataset(filename)["train"]["chosen"]
551
- new_rows = []
552
- for row in rows:
553
- if row[:2] == "\n\n":
554
- row = row[2:]
555
- row = row.replace("Human: ", "<human>: ")
556
- row = row.replace("Assistant: ", "<bot>: ")
557
- new_rows.append(dict(input=row))
558
- with open(filename.replace("/", "_") + POSTFIX, "w") as f:
559
- f.write(json.dumps(new_rows, indent=2))
560
-
561
-
562
- def test_show_prompts():
563
- files = ['config.json'] * 1 + \
564
- ['dai_docs.train_cleaned.json'] * 1 + \
565
- ['dai_faq.json'] * 1
566
- file_points = [json.load(open(fil, 'rt')) for fil in files]
567
- from prompter import generate_prompt
568
- for data_points in file_points:
569
- for data_point in data_points:
570
- print(generate_prompt(data_point, 'plain', False, False)[0])
571
-
572
-
573
- def test_get_open_datasets():
574
- # HF changed things so don't get raw list of all datasets, so not have to filter, but can't do negative filter
575
- open_tags = ['license:Apache License 2.0',
576
- 'license:mit',
577
- 'license:apache',
578
- 'license:apache2',
579
- 'license:apache-2.0',
580
- 'license:bsd',
581
- 'license:bsd-2-clause',
582
- 'license:bsd-3-clause',
583
- 'license:bsd-3-clause-clear',
584
- 'license:lgpl-2.1',
585
- 'license:lgpl-3.0',
586
- 'license:lgpl-lr',
587
- 'license:lgpl',
588
- 'license:openrail++',
589
- 'license:openrail',
590
- 'license:bigscience-bloom-rail-1.0',
591
- # 'license:agpl-3.0',
592
- 'license:other',
593
- 'license:unknown',
594
- # 'license:mpl-2.0', # ok, but would have to include original copyright, license, source, copies in distribution
595
- # Attribution required:
596
- 'license:odc-by',
597
- 'license:cc-by-4.0',
598
- 'license:cc-by-3.0',
599
- 'license:cc-by-2.0',
600
- 'license:cc-by-2.5',
601
- # 'license:cc-by-sa-4.0', # would require same license
602
- 'license:odbl',
603
- 'license:pddl',
604
- 'license:ms-pl',
605
- 'license:zlib',
606
- ]
607
- # bad license: cc-by-nc-4.0
608
-
609
- from huggingface_hub import list_datasets
610
- datasets = flatten_list([[x for x in list_datasets(filter=y)] for y in open_tags])
611
- datasets += [x for x in list_datasets(author='openai')]
612
- # check all:
613
- all_license_tags = set(flatten_list([[y for y in x.tags if 'license' in y] for x in datasets]))
614
- print(len(all_license_tags))
615
- open_datasets = [x for x in datasets if any([y in x.tags for y in open_tags]) or 'license:' not in str(x.tags)]
616
- print('open_datasets', len(open_datasets))
617
- all_task_tags = set(flatten_list([[y for y in x.tags if 'task' in y] for x in open_datasets]))
618
- print('all_task_tags', len(all_task_tags))
619
- excluded_tags = ['image', 'hate', 'tabular', 'table-', 'classification', 'retrieval',
620
- 'translation', 'identification', 'object', 'mask', 'to-text',
621
- 'face-detection', 'audio', 'voice', 'reinforcement', 'depth-est',
622
- 'forecasting', 'parsing', 'visual', 'speech', 'multiple-choice',
623
- 'slot-filling', 'irds/argsme', '-scoring', 'other', 'graph-ml',
624
- 'feature-extraction', 'keyword-spotting',
625
- 'coreference-resolution', 'segmentation',
626
- 'word-sense-disambiguation',
627
- 'lemmatization']
628
- task_tags = [x.replace('task_categories:', '').replace('task_ids:', '')
629
- for x in all_task_tags if not any([y in x for y in
630
- excluded_tags])]
631
- print('task_tags', len(task_tags))
632
- # str(x.tags) to catch any pattern match to anything in list
633
- open_tasked_datasets = [x for x in open_datasets if
634
- any([y in str([x for x in x.tags if 'task' in x]) for y in task_tags]) and
635
- not any([y in str([x for x in x.tags if 'task' in x]) for y in excluded_tags]) or
636
- 'task_categories' not in str(x.tags) and 'task_ids' not in str(x.tags)]
637
- open_tasked_datasets = [x for x in open_tasked_datasets if not x.disabled]
638
- open_tasked_datasets = [x for x in open_tasked_datasets if not x.gated]
639
- open_tasked_datasets = [x for x in open_tasked_datasets if not x.private]
640
- print('open_tasked_datasets', len(open_tasked_datasets))
641
- sizes = list(set(flatten_list([[(y, x.id) for y in x.tags if 'size' in y] for x in open_tasked_datasets])))
642
- languages = list(set(flatten_list([[(y, x.id) for y in x.tags if 'language:' in y] for x in open_tasked_datasets])))
643
- open_english_tasked_datasets = [x for x in open_tasked_datasets if
644
- 'language:' not in str(x.tags) or
645
- 'language:en' in str(x.tags)]
646
- small_open_english_tasked_datasets = [x for x in open_english_tasked_datasets if
647
- 'n<1K' in str(x.tags) or
648
- '1K<n<10K' in str(x.tags) or
649
- '1K0<n<100K' in str(x.tags) or
650
- '100K<n<1M' in str(x.tags) or
651
- 'size_category' not in str(x.tags)
652
- ]
653
- # 'aeslc' : email_body, subject -> summarization?
654
- # load_dataset(open_tasked_datasets[0].id).data['train'].to_pandas()
655
- ids = [x.id for x in small_open_english_tasked_datasets]
656
-
657
- # sanity checks
658
- # https://bair.berkeley.edu/blog/2023/04/03/koala/
659
- assert 'alespalla/chatbot_instruction_prompts' in ids
660
- assert 'laion/OIG' in ids
661
- assert 'openai/webgpt_comparisons' in ids
662
- assert 'openai/summarize_from_feedback' in ids
663
- assert 'Anthropic/hh-rlhf' in ids
664
-
665
- # useful but not allowed for commercial purposes:
666
- # https://huggingface.co/datasets/squad
667
-
668
- print('open_english_tasked_datasets: ', ids, flush=True)
669
-
670
- exclude_ids = ['allenai/nllb', # translation only
671
- 'hf-internal-testing/fixtures_image_utils', # testing
672
- 'allenai/c4', # search-url
673
- 'agemagician/uniref50', # unknown
674
- 'huggingface-course/documentation-images', # images
675
- 'smilegate-ai/kor_unsmile', # korean
676
- 'MohamedRashad/ChatGPT-prompts', # ChatGPT/LearnGPT/https://www.emergentmind.com/
677
- 'humarin/chatgpt-paraphrases', # Paraphrase using ChatGPT
678
- 'Jeska/vaccinchat', # not useful
679
- 'alespalla/chatbot_instruction_prompts', # mixes alpaca
680
- 'allenai/prosocial-dialog',
681
- # already exlucded, but wrongly in other datasets that say more permissive license
682
- 'AlekseyKorshuk/persona-chat', # low quality
683
- 'bavard/personachat_truecased', # low quality
684
- 'adamlin/daily_dialog', # medium quality conversations
685
- 'adamlin/FewShotWoz', # low quality
686
- 'benjaminbeilharz/better_daily_dialog', # low quality
687
- 'benjaminbeilharz/daily_dialog_w_turn_templates', # low
688
- 'benjaminbeilharz/empathetic_dialogues_for_lm', # low
689
- 'GEM-submissions/GEM__bart_base_schema_guided_dialog__1645547915', # NA
690
- 'ia-bentebib/conv_ai_2_fr', # low fr
691
- 'ia-bentebib/daily_dialog_fr', # low fr
692
- 'ia-bentebib/dialog_re_fr', # low fr
693
- 'ia-bentebib/empathetic_dialogues_fr', # low fr
694
- 'roskoN/dailydialog', # low
695
- 'VadorMazer/skyrimdialogstest', # low
696
- 'bigbio/med_qa', # med specific Q/A
697
- 'biu-nlp/qa_srl2018', # low quality Q/A
698
- 'biu-nlp/qa_discourse', # low quality Q/A
699
- 'iarfmoose/qa_evaluator', # low quality Q/A
700
- 'jeopardy', # low quality Q/A -- no reasoning
701
- 'narrativeqa', # low quality Q/A
702
- 'nomic-ai/gpt4all_prompt_generations', # bad license
703
- 'nomic-ai/gpt4all_prompt_generations_with_p3', # bad license
704
- 'HuggingFaceH4/alpaca', # bad license
705
- 'tatsu-lab/alpaca', # ToS breaking
706
- 'yahma/alpaca-cleaned', # ToS breaking
707
- 'Hello-SimpleAI/HC3', # bad license
708
- 'glue', # no reasoning QA
709
- 'sahil2801/CodeAlpaca-20k', # bad license
710
- 'Short-Answer-Feedback/saf_communication_networks_english', # long Q, medium A
711
- ]
712
- small_open_english_tasked_datasets = [x for x in small_open_english_tasked_datasets if x.id not in exclude_ids]
713
- # some ids clearly speech related
714
- small_open_english_tasked_datasets = [x for x in small_open_english_tasked_datasets if 'speech' not in x.id]
715
- # HF testing
716
- small_open_english_tasked_datasets = [x for x in small_open_english_tasked_datasets if
717
- 'hf-internal-testing' not in x.id]
718
- small_open_english_tasked_datasets = [x for x in small_open_english_tasked_datasets if
719
- 'chinese' not in x.id]
720
-
721
- sorted_small_open_english_tasked_datasets = sorted([(x.downloads, x) for x in small_open_english_tasked_datasets],
722
- key=lambda x: x[0], reverse=True)
723
-
724
- # NOTES:
725
- # Run like pytest -s -v create_data.py::test_get_open_datasets &> getdata9.log
726
- # See what needs config passed and add:
727
- # grep 'load_dataset(' getdata9.log|grep -v data_id|less -S
728
- # grep "pip install" getdata9.log
729
- # NOTE: Some datasets have default config, but others are there. Don't know how to access them.
730
-
731
- """
732
- https://huggingface.co/datasets/wikihow/blob/main/wikihow.py
733
- https://github.com/mahnazkoupaee/WikiHow-Dataset
734
- https://ucsb.box.com/s/ap23l8gafpezf4tq3wapr6u8241zz358
735
- https://ucsb.app.box.com/s/ap23l8gafpezf4tq3wapr6u8241zz358
736
- """
737
-
738
- """
739
- # some ambiguous or non-commercial datasets
740
- https://github.com/PhoebusSi/alpaca-CoT
741
- """
742
-
743
- timeout = 3 * 60
744
- # laion/OIG takes longer
745
- for num_downloads, dataset in sorted_small_open_english_tasked_datasets:
746
- data_id = dataset.id
747
- func = do_one
748
- args = (data_id, num_downloads)
749
- kwargs = {}
750
- with ProcessPoolExecutor(max_workers=1) as executor:
751
- future = executor.submit(func, *args, **kwargs)
752
- try:
753
- future.result(timeout=timeout)
754
- except concurrent.futures.TimeoutError:
755
- print("\n\ndata_id %s timeout\n\n" % data_id, flush=True)
756
- for child in psutil.Process(os.getpid()).children(recursive=True):
757
- os.kill(child.pid, signal.SIGINT)
758
- os.kill(child.pid, signal.SIGTERM)
759
- os.kill(child.pid, signal.SIGKILL)
760
-
761
-
762
- def do_one(data_id, num_downloads):
763
- from datasets import load_dataset
764
- out_file = "data_%s.parquet" % str(data_id.replace('/', '_'))
765
- if os.path.isfile(out_file) and os.path.getsize(out_file) > 1024 ** 3:
766
- return
767
- try:
768
- print("Loading data_id %s num_downloads: %s" % (data_id, num_downloads), flush=True)
769
- avail_list = None
770
- try:
771
- data = load_dataset(data_id, 'foobar')
772
- except Exception as e:
773
- if 'Available: ' in str(e):
774
- avail_list = ast.literal_eval(str(e).split('Available:')[1].strip())
775
- else:
776
- avail_list = None
777
- if avail_list is None:
778
- avail_list = [None]
779
- print("%s avail_list: %s" % (data_id, avail_list), flush=True)
780
-
781
- for name in avail_list:
782
- out_file = "data_%s_%s.parquet" % (str(data_id.replace('/', '_')), str(name))
783
- if os.path.isfile(out_file):
784
- continue
785
- data = load_dataset(data_id, name)
786
- column_names_dict = data.column_names
787
- column_names = column_names_dict[list(column_names_dict.keys())[0]]
788
- print("Processing data_id %s num_downloads: %s columns: %s" % (data_id, num_downloads, column_names),
789
- flush=True)
790
- data_dict = data.data
791
- col_dict = data.num_columns
792
- first_col = list(col_dict.keys())[0]
793
- if 'train' in data_dict:
794
- df = data['train'].to_pandas()
795
- else:
796
- df = data[first_col].to_pandas()
797
- # csv has issues with escaping chars, even for datasets I know I want
798
- df.to_parquet(out_file, index=False)
799
- except Exception as e:
800
- t, v, tb = sys.exc_info()
801
- ex = ''.join(traceback.format_exception(t, v, tb))
802
- print("Exception: %s %s" % (data_id, ex), flush=True)
803
-
804
-
805
- def test_otherlic():
806
- from huggingface_hub import list_datasets
807
- lic = ['license:odc-by',
808
- 'license:cc-by-4.0',
809
- 'license:cc-by-3.0',
810
- 'license:cc-by-2.0',
811
- 'license:cc-by-2.5',
812
- 'license:cc-by-sa-4.0',
813
- 'license:odbl',
814
- 'license:pddl',
815
- 'license:ms-pl',
816
- 'license:zlib',
817
- ]
818
- datasets = flatten_list([[x for x in list_datasets(filter=y) if 'translation' not in str(x.tags)] for y in lic])
819
- print(len(datasets))
820
-
821
-
822
- # These useful datasets are determined based upon data sample, column types, and uniqueness compared to larger datasets like Pile
823
- # grep columns getdata13.log|grep -v "\['image'\]"|sort|uniq|grep -v tokens|grep -v "'image'"|grep -v embedding|grep dialog
824
- useful = ['Dahoas/instruct-human-assistant-prompt',
825
- 'Dahoas/first-instruct-human-assistant-prompt',
826
- 'knkarthick/dialogsum', # summary of conversation
827
- 'McGill-NLP/FaithDial', # medium quality
828
- 'Zaid/quac_expanded', # medium quality context + QA
829
- '0-hero/OIG-small-chip2', # medium
830
- 'alistvt/coqa-flat', # QA medium
831
- 'AnonymousSub/MedQuAD_47441_Question_Answer_Pairs', # QA medium
832
- 'Anthropic/hh-rlhf', # high quality # similar to Dahoas/full-hh-rlhf
833
- 'arjunth2001/online_privacy_qna', # good quality QA
834
- 'Dahoas/instruct_helpful_preferences', # medium quality instruct
835
- 'Dahoas/rl-prompt-dataset', # medium chat
836
- 'Dahoas/rm-static', # medium chat
837
- 'Dahoas/static-hh', # medium chat # HuggingFaceH4/self_instruct
838
- 'Dahoas/synthetic-instruct-gptj-pairwise', # medium chat
839
- 'eli5', # QA if prompt ELI5
840
- 'gsm8k', # QA (various)
841
- 'guanaco/guanaco', # prompt/response
842
- 'kastan/rlhf-qa-comparisons', # good QA
843
- 'kastan/rlhf-qa-conditional-generation-v2', # prompt answer
844
- 'OllieStanley/humaneval-mbpp-codegen-qa', # code QA, but started from words, so better than other code QA
845
- 'OllieStanley/humaneval-mbpp-testgen-qa', # code QA
846
- 'Graverman/Instruct-to-Code', # code QA
847
- 'openai/summarize_from_feedback', # summarize
848
- 'relbert/analogy_questions', # analogy QA
849
- 'yitingxie/rlhf-reward-datasets', # prompt, chosen, rejected.
850
- 'yizhongw/self_instruct', # instruct (super natural & instruct)
851
- 'HuggingFaceH4/asss', # QA, big A
852
- 'kastan/rlhf-qa-conditional-generation-v2', # QA
853
- 'cosmos_qa', # context QA
854
- 'vishal-burman/c4-faqs', # QA but not so much reasoning, but alot of text
855
- 'squadshifts', # QA from context
856
- 'hotpot_qa', # QA from context
857
- 'adversarial_qa', # QA from context
858
- 'allenai/soda', # dialog -> narrative/summary
859
- 'squad_v2', # context QA
860
- 'squadshifts', # context QA
861
- 'dferndz/cSQuAD1', # context QA
862
- 'dferndz/cSQuAD2', # context QA
863
- 'din0s/msmarco-nlgen', # context QA
864
- 'domenicrosati/TruthfulQA', # common sense truthful QA -- trivia but good trivia
865
- 'hotpot_qa', # context, QA
866
- 'HuggingFaceH4/self-instruct-eval', # instruct QA, medium quality, some language reasoning
867
- 'kastan/EE_QA_for_RLHF', # context QA
868
- 'KK04/LogicInference_OA', # instruction logical QA
869
- 'lmqg/qa_squadshifts_synthetic', # context QA
870
- 'lmqg/qg_squad', # context QA
871
- 'lmqg/qg_squadshifts', # context QA
872
- 'lmqg/qg_subjqa', # context QA
873
- 'pszemraj/HC3-textgen-qa',
874
- # QA medium, has human responses -- humans tend to provide links instead of trying to answer
875
- 'pythonist/newdata', # long context, QA, brief A
876
- 'ropes', # long background, situation, question, A
877
- 'wikitablequestions', # table -> QA
878
- 'bigscience/p3', # context QA but short answers
879
- ]
880
-
881
- code_useful = ['0n1xus/codexglue',
882
- 'openai_humaneval',
883
- 'koutch/staqc',
884
- ]
885
-
886
- maybe_useful = ['AlekseyKorshuk/comedy-scripts',
887
- 'openbookqa', # hard to parse, low reasoning
888
- 'qed', # reasonable QA, but low reasoning
889
- 'selqa', # candidate answers
890
- 'HuggingFaceH4/instruction-pilot-outputs-filtered',
891
- 'GBaker/MedQA-USMLE-4-options', # medical QA with long questions
892
- 'npc-engine/light-batch-summarize-dialogue', # dialog summarize, kinda low specific quality
893
- ]
894
-
895
- summary_useful = ['austin/rheum_abstracts',
896
- 'CarperAI/openai_summarize_comparisons', # summarize chosen/rejected
897
- 'CarperAI/openai_summarize_tldr', # summarize QA
898
- 'ccdv/cnn_dailymail', # summarize news
899
- 'ccdv/govreport-summarization', # summarize high quality
900
- 'ccdv/pubmed-summarization', # summarize high quality
901
- 'duorc', # plot -> QA
902
- 'farleyknight/big_patent_5_percent', # desc -> abstract
903
- 'multi_news', # summary
904
- 'opinosis',
905
- 'SophieTr/reddit_clean',
906
- 'allenai/mup', # long text -> summary
907
- 'allenai/multi_lexsum', # long text -> summary
908
- 'big_patent',
909
- 'allenai/wcep_dense_max',
910
- 'awinml/costco_long_practice',
911
- 'GEM/xsum',
912
- 'ratishsp/newshead',
913
- 'RussianNLP/wikiomnia', # russian
914
- 'stacked-summaries/stacked-xsum-1024',
915
- ]
916
-
917
- math_useful = [
918
- 'competition_math'
919
- ]
920
-
921
- skipped = ['c4', # maybe useful, used for flan, but skipped due to size
922
- ]
923
-
924
- """
925
- To get training data from oig:
926
- pytest test_oig test_grade_final test_finalize_to_json
927
- """
928
-
929
- human = '<human>:'
930
- bot = '<bot>:'
931
-
932
-
933
- def test_assemble_and_detox():
934
- import re
935
- from profanity_check import predict_prob
936
- df_list = []
937
- for data in useful_oig_files:
938
- print("Processing %s" % data, flush=True)
939
- df = pd.read_parquet(data)
940
- df = df.reset_index(drop=True)
941
- # chop up into human/bot interactions of no more than 10kB per row
942
- text_list = df[['text']].values.ravel().tolist()
943
- new_text = []
944
- max_len = 2048 # uber cutoff
945
- MAX_LEN = 2048 // 2 - 30 # max len per question/answer
946
- for text in tqdm(text_list):
947
- human_starts = [m.start() for m in re.finditer('<human>: ', text)]
948
- if len(human_starts) == 1:
949
- human_starts = [0, len(text)] # always go into for loop below
950
- blurb = ''
951
- for i in range(len(human_starts) - 1):
952
- interaction = text[human_starts[i]: human_starts[i + 1]][:max_len]
953
- blurb += interaction
954
- if len(blurb) >= MAX_LEN:
955
- blurb = get_sentences(blurb, length=MAX_LEN)[0]
956
- new_text.append(blurb + "\n<human>:")
957
- blurb = ''
958
- if blurb:
959
- blurb = get_sentences(blurb, length=MAX_LEN)[0]
960
- new_text.append(blurb + "\n<human>:")
961
-
962
- if len(new_text) > len(text_list):
963
- print("Added %d new rows (before: %d)" % (len(new_text) - df.shape[0], df.shape[0]))
964
- df = pd.DataFrame({"text": new_text, "source": [data] * len(new_text)})
965
- df = df.drop_duplicates(keep='first')
966
- print(df['text'].apply(lambda x: len(x)).describe())
967
- assert df['text'].apply(lambda x: len(x)).max() <= 2 * max_len
968
-
969
- # faster than better_profanity, do early
970
- df['profanity'] = predict_prob(df['text'])
971
- before_rows = df.shape[0]
972
- df = df[df['profanity'] < 0.25] # drop any low quality stuff
973
- after_rows = df.shape[0]
974
- print("Dropped %d rows out of %d due to alt-profanity-check" % (before_rows - after_rows, before_rows))
975
- df_list.append(df)
976
- print("Done processing %s -> %s rows" % (data, df.shape[0]), flush=True)
977
- print("So far have %d rows" % sum([len(x) for x in df_list]))
978
- df_final = pd.concat(df_list)
979
- df_final = df_final.sample(frac=1, random_state=1234).reset_index(drop=True)
980
- df_final.to_parquet('h2oGPT.cleaned.human_bot.shorter.parquet', index=False)
981
-
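For reference, the profanity gate in test_assemble_and_detox relies on alt-profanity-check's predict_prob, which maps a list of strings to probabilities in [0, 1]. A minimal standalone sketch of that filter, on toy strings, with the same 0.25 threshold:

```python
# Minimal sketch of the alt-profanity-check filter used above (toy data).
import pandas as pd
from profanity_check import predict_prob

df = pd.DataFrame({"text": ["<human>: hi\n<bot>: hello\n<human>:",
                            "<human>: tell me a joke\n<bot>: why did...\n<human>:"]})
df["profanity"] = predict_prob(df["text"].tolist())
df = df[df["profanity"] < 0.25]  # drop rows the classifier flags as likely profane
print(df.shape[0], "rows kept")
```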
982
-
983
- def test_basic_cleaning():
984
- # from better_profanity import profanity
985
- # https://pypi.org/project/alt-profanity-check/
986
- from profanity_check import predict
987
- df_list = []
988
- for data in useful_oig_files:
989
- # for data in useful_oig_files[:5]:
990
- # for data in ['unified_openai_summarize_tldr.jsonl.parquet']:
991
- print("Processing %s" % data, flush=True)
992
- df = pd.read_parquet(data)
993
- df = df.reset_index(drop=True)
994
- # NOTE: Not correct if there are multiple human-bot interactions, but such dialogs are even more desirable
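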
995
- # avg_chars = len(df['text'][0])/(df['text'][0].count(human)+df['text'][0].count(bot))
996
- df['avg_words'] = df['text'].apply(lambda x: x.count(' ') / (x.count(human) + x.count(bot)) / 2.0)
997
- df['avg_bot_words'] = df['text'].apply(lambda x: x.split(bot)[1].count(' ') / x.count(bot))
998
- # df['bad_words'] = df['text'].apply(lambda x: profanity.contains_profanity(x))
999
- # low_quality_patterns = ['Write the rest of this wikipedia article']
1000
- res = predict(df['text'])
1001
- df['bad_words'] = res
1002
- df = df.reset_index(drop=True)
1003
- df = df[df['bad_words'] == 0]
1004
- df = df[['text', 'avg_words', 'avg_bot_words']]
1005
- df = df.drop_duplicates(keep='first')
1006
- print(df[df['avg_words'] == df['avg_words'].max()]['text'].values)
1007
- median_words = np.median(df['avg_words'])
1008
- min_words_per_entity = max(30, 0.8 * median_words)
1009
- max_words_per_entity = 2048 # too hard to learn from for now
1010
- df = df[df['avg_words'] > min_words_per_entity]
1011
- df = df[df['avg_words'] < max_words_per_entity]
1012
-
1013
- min_words_per_entity = max(20, 0.5 * median_words) # bot should say stuff for now
1014
- max_words_per_entity = 2048 # too hard to learn from for now
1015
- df = df[df['avg_bot_words'] > min_words_per_entity]
1016
- df = df[df['avg_bot_words'] < max_words_per_entity]
1017
-
1018
- df_list.append(df)
1019
- print("Done processing %s -> %s rows" % (data, df.shape[0]), flush=True)
1020
- df_final = pd.concat(df_list)
1021
- df_final.to_parquet('h2oGPT.cleaned.human_bot.parquet', index=False)
1022
-
1023
-
1024
- from joblib import Parallel, delayed, effective_n_jobs
1025
- from sklearn.utils import gen_even_slices
1026
- from sklearn.utils.validation import _num_samples
1027
-
1028
-
1029
- def parallel_apply(df, func, n_jobs=-1, **kwargs):
1030
- """ Pandas apply in parallel using joblib.
1031
- Uses sklearn.utils to partition input evenly.
1032
-
1033
- Args:
1034
- df: Pandas DataFrame, Series, or any other object that supports slicing and apply.
1035
- func: Callable to apply
1036
- n_jobs: Desired number of workers. Default value -1 means use all available cores.
1037
- **kwargs: Any additional parameters will be supplied to the apply function
1038
-
1039
- Returns:
1040
- Same as for normal Pandas DataFrame.apply()
1041
-
1042
- """
1043
-
1044
- if effective_n_jobs(n_jobs) == 1:
1045
- return df.apply(func, **kwargs)
1046
- else:
1047
- ret = Parallel(n_jobs=n_jobs)(
1048
- delayed(type(df).apply)(df[s], func, **kwargs)
1049
- for s in gen_even_slices(_num_samples(df), effective_n_jobs(n_jobs)))
1050
- return pd.concat(ret)
1051
-
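A small usage example of parallel_apply on a toy Series; the result should match a plain Series.apply, just computed across joblib workers (this assumes parallel_apply from this module is importable):

```python
# Toy check that parallel_apply agrees with Series.apply.
import pandas as pd

s = pd.Series(["short", "a somewhat longer piece of text", "medium length"])
out = parallel_apply(s, len, n_jobs=2)
assert out.tolist() == s.apply(len).tolist()
```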
1052
-
1053
- def add_better_profanity_flag(df):
1054
- from better_profanity import profanity
1055
- df['better_profanity'] = parallel_apply(
1056
- df['text'],
1057
- lambda x: profanity.contains_profanity(x),
1058
- n_jobs=-1,
1059
- )
1060
- return df
1061
-
1062
-
1063
- def add_textstat_grade(df):
1064
- import textstat
1065
-
1066
- def myfunc(x):
1067
- return textstat.flesch_kincaid_grade(x) # simple grade
1068
-
1069
- if False:
1070
- import dask.dataframe as dd
1071
- # 40 seconds for 1000 rows, but have 1,787,799 rows
1072
- ddata = dd.from_pandas(df, npartitions=120)
1073
-
1074
- df['flesch_grade'] = ddata['text'].apply(myfunc).compute()
1075
- if True:
1076
- # fast way
1077
- df['flesch_grade'] = parallel_apply(df['text'], myfunc, n_jobs=-1)
1078
- return df
1079
-
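The grading itself is just textstat's Flesch-Kincaid score per row; a one-off example of the underlying call:

```python
# The readability score wrapped by add_textstat_grade: roughly a US school
# grade level for the given text.
import textstat

print(textstat.flesch_kincaid_grade("The quick brown fox jumps over the lazy dog."))
```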
1080
-
1081
- def add_deberta_grade(df):
1082
- from transformers import AutoModelForSequenceClassification, AutoTokenizer
1083
- import torch
1084
- reward_name = "OpenAssistant/reward-model-deberta-v3-large-v2"
1085
- rank_model, tokenizer = AutoModelForSequenceClassification.from_pretrained(
1086
- reward_name), AutoTokenizer.from_pretrained(reward_name)
1087
- device = 'cuda' if torch.cuda.is_available() else 'cpu'
1088
- rank_model.to(device)
1089
-
1090
- def get_question(x):
1091
- return x.replace('<human>: ', '').split('<bot>:')[0]
1092
-
1093
- def get_answer(x):
1094
- try:
1095
- answer = x.split('<bot>: ')[1].split('<human>:')[0].replace('<bot>: ', '')
1096
- except:
1097
- answer = x.split('<bot>:')[1].split('<human>:')[0].replace('<bot>:', '')
1098
- return answer
1099
-
1100
- df['question'] = parallel_apply(df['text'], get_question, n_jobs=-1)
1101
- df['answer'] = parallel_apply(df['text'], get_answer, n_jobs=-1)
1102
-
1103
- from datasets import Dataset
1104
- from transformers import pipeline
1105
- from transformers.pipelines.pt_utils import KeyPairDataset
1106
- import tqdm
1107
-
1108
- pipe = pipeline(
1109
- "text-classification",
1110
- model=reward_name,
1111
- device="cuda:0" if torch.cuda.is_available() else "cpu"
1112
- )
1113
- start = 0
1114
- batch_size = 64 * 16
1115
- micro_batch = orig_micro_batch = 16
1116
- end = 0
1117
- import socket
1118
- checkpoint = "grades.%s.pkl" % socket.gethostname()
1119
- grades = []
1120
- import pickle
1121
- if os.path.exists(checkpoint):
1122
- with open(checkpoint, "rb") as f:
1123
- start, grades = pickle.loads(f.read())
1124
- last_oom = 0
1125
- while end < df.shape[0]:
1126
- # manual batching to handle OOM more gracefully
1127
- end = min(start + batch_size, df.shape[0])
1128
- if start == end:
1129
- break
1130
- dataset = Dataset.from_pandas(df.iloc[start:end, :])
1131
- try:
1132
- grades.extend([
1133
- x['score'] for x in tqdm.tqdm(
1134
- pipe(KeyPairDataset(dataset, "question", "answer"), batch_size=micro_batch)
1135
- )
1136
- ])
1137
- except torch.cuda.OutOfMemoryError:
1138
- last_oom = start
1139
- micro_batch = max(1, micro_batch // 2)
1140
- print("OOM - retrying with micro_batch=%d" % micro_batch)
1141
- continue
1142
- if last_oom == start:
1143
- micro_batch = orig_micro_batch
1144
- print("Returning to micro_batch=%d" % micro_batch)
1145
- assert len(grades) == end
1146
- start = end
1147
- with open(checkpoint, "wb") as f:
1148
- f.write(pickle.dumps((end, grades)))
1149
- print("%d/%d" % (end, df.shape[0]))
1150
- df['grade_deberta'] = grades
1151
- if os.path.exists(checkpoint):
1152
- os.remove(checkpoint)
1153
- return df
1154
-
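To make the batched pipeline above easier to follow, here is a hedged sketch of scoring a single question/answer pair with the same OpenAssistant reward model. The toy question/answer strings are assumptions, and it prints the raw logit, whereas the text-classification pipeline above applies its own post-processing to produce `score`:

```python
# Single-pair scoring with the reward model used in add_deberta_grade.
import torch
from transformers import AutoModelForSequenceClassification, AutoTokenizer

reward_name = "OpenAssistant/reward-model-deberta-v3-large-v2"
rank_model = AutoModelForSequenceClassification.from_pretrained(reward_name)
tokenizer = AutoTokenizer.from_pretrained(reward_name)

question = "What is a parquet file?"           # toy example
answer = "Parquet is a columnar file format."  # toy example
inputs = tokenizer(question, answer, return_tensors="pt")
with torch.no_grad():
    score = rank_model(**inputs).logits[0].item()
print(score)  # higher means the answer is judged more helpful
```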
1155
-
1156
- def test_chop_by_lengths():
1157
- file = "h2oGPT.cleaned.human_bot.shorter.parquet"
1158
- df = pd.read_parquet(file).reset_index(drop=True)
1159
- df = count_human_bot_lengths(df)
1160
- df['rand'] = np.random.rand(df.shape[0])
1161
- df['rand2'] = np.random.rand(df.shape[0])
1162
- before_rows = df.shape[0]
1163
- # throw away short human/bot responses with higher likelihood
1164
- df = df[(df['len_human_mean'] > 20)] # never keep very short ones
1165
- df = df[(df['len_human_mean'] > 30) | (df['rand'] < 0.2)]
1166
- df = df[(df['len_human_mean'] > 50) | (df['rand'] < 0.5)]
1167
- df = df[(df['len_human_max'] < 10000)] # drop super long (basically only human) ones
1168
- df = df[(df['len_bot_mean'] > 20)] # never keep very short ones
1169
- df = df[(df['len_bot_mean'] > 30) | (df['rand2'] < 0.2)]
1170
- df = df[(df['len_bot_mean'] > 50) | (df['rand2'] < 0.5)]
1171
- df = df[(df['len_bot_max'] < 10000)] # drop super long (only bot) ones
1172
- assert df['text'].apply(lambda x: len(x)).max() < 20000
1173
- df = df.drop(['rand', 'rand2'], axis=1)
1174
- after_rows = df.shape[0]
1175
- print("Chopped off %d out of %d rows due to length" % (before_rows - after_rows, before_rows))
1176
- print(df.describe())
1177
- df.to_parquet('h2oGPT.cleaned.chopped.human_bot.shorter.parquet', index=False)
1178
-
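The keep/drop rule above is probabilistic: very short human or bot turns are always dropped, and medium-length ones are kept only with some probability. A compact sketch of the same pattern on toy data:

```python
# Probabilistic length filter, same shape as the rules in test_chop_by_lengths.
import numpy as np
import pandas as pd

df = pd.DataFrame({"len_human_mean": [25, 40, 80]})
df["rand"] = np.random.rand(df.shape[0])
keep = (df["len_human_mean"] > 20) \
       & ((df["len_human_mean"] > 30) | (df["rand"] < 0.2)) \
       & ((df["len_human_mean"] > 50) | (df["rand"] < 0.5))
print(df[keep].drop(columns=["rand"]))
```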
1179
-
1180
- def count_human_bot_lengths(df, human=None, bot=None):
1181
- import re
1182
- len_human_min = []
1183
- len_human_max = []
1184
- len_human_mean = []
1185
- len_bot_min = []
1186
- len_bot_max = []
1187
- len_bot_mean = []
1188
- human = human or '<human>:'
1189
- bot = bot or '<bot>:'
1190
- for is_human in [True, False]:
1191
- what = human if is_human else bot
1192
- other = human if not is_human else bot
1193
- for i in range(df.shape[0]):
1194
- text = df.loc[i, 'text']
1195
- assert isinstance(text, str)
1196
- starts = [m.start() for m in re.finditer(what, text)]
1197
- if len(starts) == 1:
1198
- starts = [starts[0], len(text)] # always go into for loop below
1199
- assert len(text)
1200
- list_what = []
1201
- for ii in range(len(starts) - 1):
1202
- interaction = text[starts[ii]: starts[ii + 1]]
1203
- if other in interaction:
1204
- interaction = interaction[:interaction.find(other)]
1205
- interaction = interaction.strip()  # str.strip() returns a new string; assign it or it's a no-op
1206
- list_what.append(interaction)
1207
- if not list_what:
1208
- list_what = [''] # handle corrupted data, very rare, leads to sizes 0
1209
- if is_human:
1210
- len_human_min.append(min([len(x) for x in list_what]))
1211
- len_human_max.append(max([len(x) for x in list_what]))
1212
- len_human_mean.append(np.mean([len(x) for x in list_what]))
1213
- else:
1214
- len_bot_min.append(min([len(x) for x in list_what]))
1215
- len_bot_max.append(max([len(x) for x in list_what]))
1216
- len_bot_mean.append(np.mean([len(x) for x in list_what]))
1217
- df['len_human_min'] = len_human_min
1218
- df['len_human_max'] = len_human_max
1219
- df['len_human_mean'] = len_human_mean
1220
- df['len_bot_min'] = len_bot_min
1221
- df['len_bot_max'] = len_bot_max
1222
- df['len_bot_mean'] = len_bot_mean
1223
- np.random.seed(1234)
1224
- pd.set_option('display.max_columns', None)
1225
- print("Before chopping")
1226
- print(df.describe())
1227
- return df
1228
-
1229
-
1230
- def test_grade():
1231
- df = None
1232
-
1233
- file = "h2oGPT.cleaned.chopped.human_bot.shorter.parquet"
1234
- output_file = "h2oGPT.cleaned.graded1.human_bot.shorter.parquet"
1235
- if not os.path.exists(output_file):
1236
- if df is None:
1237
- df = pd.read_parquet(file).reset_index(drop=True)
1238
- df = add_textstat_grade(df)
1239
- min_grade = 10
1240
- max_grade = 25
1241
- df = df[df['flesch_grade'] >= min_grade]
1242
- df = df[df['flesch_grade'] <= max_grade]
1243
- print("After Flesch grade")
1244
- print(df.describe())
1245
- df.to_parquet(output_file, index=False)
1246
-
1247
- file = output_file
1248
- output_file = "h2oGPT.cleaned.graded2.human_bot.shorter.parquet"
1249
- if not os.path.exists(output_file):
1250
- # slower than alt-profanity-check, so run it late, but before deberta grading, which is even slower
1251
- if df is None:
1252
- df = pd.read_parquet(file).reset_index(drop=True)
1253
- df = add_better_profanity_flag(df)
1254
- before_rows = df.shape[0]
1255
- df = df[df['better_profanity'] == 0]
1256
- df = df.drop(['better_profanity'], axis=1)
1257
- after_rows = df.shape[0]
1258
- print("Dropped %d rows out of %d due to better_profanity" % (before_rows - after_rows, before_rows))
1259
- print(df.describe())
1260
- df.to_parquet(output_file, index=False)
1261
-
1262
- file = output_file
1263
- output_file = 'h2oGPT.cleaned.graded3.human_bot.shorter.parquet'
1264
- if not os.path.exists(output_file):
1265
- if df is None:
1266
- df = pd.read_parquet(file).reset_index(drop=True)
1267
- df = add_deberta_grade(df)
1268
- min_grade = 0.3
1269
- max_grade = np.inf
1270
- before_rows = df.shape[0]
1271
- df = df[df['grade_deberta'] >= min_grade]
1272
- df = df[df['grade_deberta'] <= max_grade]
1273
- after_rows = df.shape[0]
1274
- print("Dropped %d rows out of %d due to deberta grade" % (before_rows - after_rows, before_rows))
1275
- print("After DeBERTa grade")
1276
- print(df.describe())
1277
- df.to_parquet(output_file, index=False)
1278
-
1279
- file = output_file
1280
- output_file = 'h2oGPT.cleaned.graded.human_bot.shorter.parquet'
1281
- if df is None:
1282
- df = pd.read_parquet(file).reset_index(drop=True)
1283
- df.to_parquet(output_file, index=False)
1284
-
1285
-
1286
- @pytest.mark.parametrize(
1287
- "fixup_personality, only_personality, deberta_grading",
1288
- [
1289
- [False, False, False],
1290
- [True, True, False],
1291
- [True, False, False],
1292
- [True, False, True],
1293
- ]
1294
- )
1295
- def test_add_open_assistant(fixup_personality, only_personality, deberta_grading, save_json=True):
1296
- """
1297
- Flatten tree structure into one row per path from root to leaf
1298
- Also turn into human_bot prompting format:
1299
- <human>: question\n<bot>: answer <human>: question2\n<bot>: answer2 Etc.
1300
- Also saves a .json locally as side-effect
1301
- returns list of dicts, containing input, prompt_type and source
1302
- """
1303
- from datasets import load_dataset
1304
- data_file = "OpenAssistant/oasst1"
1305
- ds = load_dataset(data_file)
1306
- df = pd.concat([ds['train'].to_pandas(), ds['validation'].to_pandas()], axis=0)
1307
- rows = {}
1308
- message_ids = df['message_id'].values.tolist()
1309
- message_tree_ids = df['message_tree_id'].values.tolist()
1310
- parent_ids = df['parent_id'].values.tolist()
1311
- texts = df['text'].values.tolist()
1312
- roles = df['role'].values.tolist()
1313
-
1314
- for i in range(df.shape[0]):
1315
- # collect all trees
1316
- message_id = message_ids[i]
1317
- message_tree_id = message_tree_ids[i]
1318
- parent_id = parent_ids[i]
1319
- text = texts[i]
1320
- if fixup_personality:
1321
- text = text.replace("Open Assistant", "h2oGPT")
1322
- text = text.replace("Open-Assistant", "h2oGPT")
1323
- text = text.replace("open-assistant", "h2oGPT")
1324
- text = text.replace("OpenAssistant", "h2oGPT")
1325
- text = text.replace("open assistant", "h2oGPT")
1326
- text = text.replace("Open Assistand", "h2oGPT")
1327
- text = text.replace("Open Assitant", "h2oGPT")
1328
- text = text.replace("Open Assistent", "h2oGPT")
1329
- text = text.replace("Open Assisstant", "h2oGPT")
1330
- text = text.replace("Open Assitent", "h2oGPT")
1331
- text = text.replace("Open Assitiant", "h2oGPT")
1332
- text = text.replace("Open Assistiant", "h2oGPT")
1333
- text = text.replace("Open Assitan ", "h2oGPT ")
1334
- text = text.replace("Open Assistan ", "h2oGPT ")
1335
- text = text.replace("Open Asistant", "h2oGPT")
1336
- text = text.replace("Open Assiant", "h2oGPT")
1337
- text = text.replace("Assistant", "h2oGPT")
1338
- text = text.replace("LAION AI", "H2O.ai")
1339
- text = text.replace("LAION-AI", "H2O.ai")
1340
- text = text.replace("LAION,", "H2O.ai,")
1341
- text = text.replace("LAION.ai", "H2O.ai")
1342
- text = text.replace("LAION.", "H2O.ai.")
1343
- text = text.replace("LAION", "H2O.ai")
1344
-
1345
- role = roles[i]
1346
- new_data = ('<human>: ' if role == 'prompter' else '<bot>: ') + text
1347
- entry = dict(message_id=message_id, parent_id=parent_id, text=new_data)
1348
- if message_tree_id not in rows:
1349
- rows[message_tree_id] = [entry]
1350
- else:
1351
- rows[message_tree_id].append(entry)
1352
-
1353
- all_rows = []
1354
-
1355
- for node_id in rows:
1356
- # order responses in tree, based on message/parent relationship
1357
- conversations = []
1358
-
1359
- list_msgs = rows[node_id]
1360
- # find start
1361
- while len(list_msgs):
1362
- for i, leaf in enumerate(list_msgs):
1363
- found = False
1364
- parent_id = leaf['parent_id']
1365
- if parent_id is None:
1366
- # conversation starter
1367
- conversations.append(leaf)
1368
- found = True
1369
- else:
1370
- for conv in conversations:
1371
- # find all conversations to add my message to
1372
- if parent_id in conv['message_id'] and parent_id != conv['message_id'][-len(parent_id):]:
1373
- # my message doesn't follow conversation
1374
- continue
1375
- if parent_id == conv['message_id'][-len(parent_id):]:
1376
- # my message follows conversation, but fork first, so another follow-on message can do same
1377
- conversations.append(conv.copy())
1378
- conv['text'] += f"""
1379
- {leaf['text']}
1380
- """
1381
- conv['message_id'] += leaf['message_id']
1382
- found = True
1383
- break
1384
- if found:
1385
- # my content was used, so nuke from list
1386
- del list_msgs[i]
1387
- break
1388
-
1389
- # now reduce down to final conversations, find the longest chains of message ids
1390
- for i, conv in enumerate(conversations):
1391
- for j, conv2 in enumerate(conversations):
1392
- if i == j:
1393
- continue
1394
- if conv['message_id'] and conv2['message_id']:
1395
- assert conv['message_id'] != conv2['message_id']
1396
- # delete the shorter conversation, if one contains the other
1397
- if conv['message_id'] in conv2['message_id']:
1398
- conv['message_id'] = None
1399
- elif conv2['message_id'] in conv['message_id']:  # elif: conv['message_id'] may have just been set to None
1400
- conv2['message_id'] = None
1401
- conversations = [c for c in conversations if c['message_id']]
1402
- if only_personality:
1403
- all_rows.extend(
1404
- [dict(input=c['text'] + "\n<human>:", prompt_type='plain', source=data_file) for c in conversations if
1405
- 'h2oGPT' in c['text']])
1406
- else:
1407
- all_rows.extend(
1408
- [dict(input=c['text'] + "\n<human>:", prompt_type='plain', source=data_file) for c in conversations if
1409
- "What is H2O.ai" not in c['text']])
1410
- unhelpful = get_unhelpful_list()
1411
- all_rows = [x for x in all_rows if not any(u in x['input'] for u in unhelpful)]
1412
- personality = create_personality_data()
1413
- all_rows.extend(personality * 10)
1414
- np.random.seed(123)
1415
- np.random.shuffle(all_rows)
1416
- print(len(all_rows))
1417
- if deberta_grading:
1418
- df = pd.DataFrame(all_rows)
1419
- df = df.rename(columns={'input': 'text'})
1420
- df = add_deberta_grade(df)
1421
- df = df.rename(columns={'text': 'input'})
1422
- drop = True
1423
- if drop:
1424
- min_grade = 0.3
1425
- max_grade = np.inf
1426
- before_rows = df.shape[0]
1427
- df = df[df['grade_deberta'] >= min_grade]
1428
- df = df[df['grade_deberta'] <= max_grade]
1429
- after_rows = df.shape[0]
1430
- print("Dropped %d rows out of %d due to deberta grade" % (before_rows - after_rows, before_rows))
1431
- print("After DeBERTa grade")
1432
- print(df.describe())
1433
- all_rows = []
1434
- for i in range(df.shape[0]):
1435
- all_rows.append(
1436
- dict(
1437
- input=df['input'].iloc[i],
1438
- source=df['source'].iloc[i],
1439
- prompt_type=df['prompt_type'].iloc[i],
1440
- grade_deberta=df['grade_deberta'].iloc[i],
1441
- )
1442
- )
1443
- if save_json:
1444
- data_file = data_file + \
1445
- ("_h2ogpt" if fixup_personality else "") + \
1446
- ("_only" if only_personality else "") + \
1447
- ("_graded" if deberta_grading else "")
1448
- for i in range(len(all_rows)):
1449
- all_rows[i]['id'] = i
1450
- with open(data_file.lower().replace("/", "_") + ".json", "w") as f:
1451
- f.write(json.dumps(all_rows, indent=2))
1452
- return all_rows
1453
-
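The flattening described in the docstring of test_add_open_assistant can be summarized with a toy recursive version: every root-to-leaf path in the message tree becomes one `<human>:`/`<bot>:` string. This is an illustrative sketch with made-up message ids, not the iterative algorithm used above:

```python
# Toy flattening of an oasst-style tree into one conversation per leaf.
rows = [
    dict(message_id="a", parent_id=None, role="prompter", text="Hi"),
    dict(message_id="b", parent_id="a", role="assistant", text="Hello!"),
    dict(message_id="c", parent_id="a", role="assistant", text="Hey there."),
]


def flatten(rows):
    children = {}
    for r in rows:
        children.setdefault(r["parent_id"], []).append(r)
    out = []

    def walk(node, prefix):
        tag = "<human>: " if node["role"] == "prompter" else "<bot>: "
        text = prefix + tag + node["text"] + "\n"
        kids = children.get(node["message_id"], [])
        if not kids:
            out.append(text + "<human>:")  # terminate like the code above
        for kid in kids:
            walk(kid, text)

    for root in children.get(None, []):
        walk(root, "")
    return out


print(flatten(rows))  # two conversations, one per leaf ("b" and "c")
```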
1454
-
1455
- def test_finalize_to_json():
1456
- df = pd.read_parquet('h2oGPT.cleaned.graded.human_bot.shorter.parquet')
1457
- df = df.rename(columns={'text': 'input'})
1458
-
1459
- print("Number of high-quality human_bot interactions: %s" % df.shape[0], flush=True)
1460
-
1461
- print("Adding open assistant data")
1462
- with open("openassistant_oasst1_h2ogpt_graded.json") as f:
1463
- open_assistant = json.loads(f.read())
1464
- df = pd.concat([df, pd.DataFrame(open_assistant)], axis=0)
1465
-
1466
- def final_clean(df):
1467
- from better_profanity import profanity
1468
- profanity.load_censor_words_from_file("data/censor_words.txt")
1469
- df['profanity'] = parallel_apply(
1470
- df['input'],
1471
- lambda x: profanity.contains_profanity(x),
1472
- n_jobs=-1,
1473
- )
1474
- return df[(df['profanity'] == 0)].reset_index(drop=True)
1475
-
1476
- print("Before cleaning: Number of final high-quality human_bot interactions: %s" % df.shape[0], flush=True)
1477
- df = final_clean(df)
1478
- print("After cleaning: Number of final high-quality human_bot interactions: %s" % df.shape[0], flush=True)
1479
- print(df.describe())
1480
- print(df.shape)
1481
- row_list = []
1482
- for i in range(df.shape[0]):
1483
- row_list.append(
1484
- dict(
1485
- input=df.loc[i, 'input'],
1486
- source=df.loc[i, 'source'],
1487
- prompt_type='plain',
1488
- )
1489
- )
1490
- np.random.seed(1234)
1491
- np.random.shuffle(row_list)
1492
- unhelpful = get_unhelpful_list()
1493
- row_list = [x for x in row_list if not any(u in x['input'] for u in unhelpful)]
1494
- for i in range(len(row_list)):
1495
- row_list[i]['id'] = i
1496
- row_list[i]['input'] = row_list[i]['input'].replace(" <bot>:", "\n<bot>:")
1497
- with open('h2ogpt-oig-oasst1-instruct-cleaned-v3.json', "w") as f:
1498
- f.write(json.dumps(row_list, indent=2))
1499
-
1500
-
1501
- def create_personality_data():
1502
- questions = [
1503
- "What's your name?",
1504
- "What is your name?",
1505
- "What are you?",
1506
- "Who are you?",
1507
- "Do you have a name?",
1508
- "Who trained you?",
1509
- "Who created you?",
1510
- "Who made you?",
1511
- ]
1512
- answers = [
1513
- "I'm h2oGPT, a large language model by H2O.ai.",
1514
- "I'm h2oGPT, a large language model by H2O.ai, the visionary leader in democratizing AI.",
1515
- "My name is h2oGPT. I'm a large language model by H2O.ai, the visionary leader in democratizing AI.",
1516
- "My name is h2oGPT. I'm a large language model trained by H2O.ai.",
1517
- "Hi! I'm h2oGPT, a large language model by H2O.ai.",
1518
- "Hi! I'm h2oGPT, a large language model by H2O.ai, the visionary leader in democratizing AI.",
1519
- ]
1520
- help = [
1521
- "",
1522
- " How can I help you?",
1523
- " How may I assist you?",
1524
- " Nice to meet you.",
1525
- ]
1526
- import itertools
1527
- rows = []
1528
- for pair in itertools.product(questions, answers, help):
1529
- rows.append(
1530
- dict(input=f"<human>: {pair[0]}\n<bot>: {pair[1]}{pair[2]}\n<human>:", prompt_type='plain', source="H2O.ai")
1531
- )
1532
- for row in [
1533
- "<human>: What is H2O.ai?\n<bot>: H2O.ai is a technology company that aims to democratize AI and make it accessible to a broader audience by simplifying the process of creating and deploying machine learning models.\n<human>:",
1534
- "<human>: What is h2o.ai?\n<bot>: H2O.ai is a technology company that aims to democratize AI and make it accessible to a broader audience by simplifying the process of creating and deploying machine learning models.\n<human>:",
1535
- "<human>: What is H2O?\n<bot>: H2O.ai is a technology company that aims to democratize AI and make it accessible to a broader audience by simplifying the process of creating and deploying machine learning models.\n<human>:",
1536
- "<human>: Who is h2o.ai?\n<bot>: H2O.ai is a technology company that aims to democratize AI and make it accessible to a broader audience by simplifying the process of creating and deploying machine learning models.\n<human>:",
1537
- "<human>: who is h2o.ai?\n<bot>: H2O.ai is a technology company that aims to democratize AI and make it accessible to a broader audience by simplifying the process of creating and deploying machine learning models.\n<human>:",
1538
- "<human>: who is h2o?\n<bot>: H2O.ai is a technology company that aims to democratize AI and make it accessible to a broader audience by simplifying the process of creating and deploying machine learning models.\n<human>:",
1539
- "<human>: What is H2O.ai?\n<bot>: H2O.ai is the visionary leader in democratizing AI.\n<human>:",
1540
- "<human>: Who is H2O.ai?\n<bot>: H2O.ai is the visionary leader in democratizing AI.\n<human>:",
1541
- "<human>: Who is H2O?\n<bot>: H2O.ai is the visionary leader in democratizing AI.\n<human>:",
1542
- "<human>: Who is h2o?\n<bot>: H2O.ai is the visionary leader in democratizing AI.\n<human>:",
1543
- "<human>: who is h2o?\n<bot>: H2O.ai is the visionary leader in democratizing AI.\n<human>:",
1544
- ]:
1545
- rows.append(dict(input=row, prompt_type='plain', source='H2O.ai'))
1546
- print(len(rows))
1547
- with open("h2ogpt-personality.json", "w") as f:
1548
- f.write(json.dumps(rows, indent=2))
1549
- return rows
1550
-
1551
-
1552
- def test_check_stats_data():
1553
- filename = 'h2ogpt-oig-oasst1-instruct-cleaned-v3.json'
1554
- df = pd.read_json(filename)
1555
-
1556
- # get word stats
1557
- df['char_count'] = df['input'].apply(lambda x: len(x))
1558
- import matplotlib.pyplot as plt
1559
- plt.figure(figsize=(10, 10))
1560
- plt.hist(df['char_count'], bins=100)
1561
- chars_avg = np.mean(df['char_count'])
1562
- chars_median = np.median(df['char_count'])
1563
- plt.title("char_count avg: %s median: %s" % (chars_avg, chars_median))
1564
- plt.savefig('chars_hist.png')
1565
- plt.close()
1566
-
1567
- # get tokenize stats for random sample of 1000 rows
1568
- from finetune import generate_and_tokenize_prompt
1569
- from loaders import get_loaders, get_tokenizer
1570
- from functools import partial
1571
-
1572
- llama_type = False
1573
- tokenizer_base_model = base_model = 'h2oai/h2ogpt-oasst1-512-20b'
1574
- model_loader, tokenizer_loader = get_loaders(llama_type=llama_type, model_name=base_model, reward_type=False)
1575
- local_files_only = False
1576
- resume_download = True
1577
- use_auth_token = False
1578
- tokenizer = get_tokenizer(tokenizer_loader, tokenizer_base_model, local_files_only, resume_download, use_auth_token)
1579
- prompt_type = 'plain' # trained with data already in human bot form
1580
- train_on_inputs = True
1581
- add_eos_token = False
1582
- cutoff_len = 512 # can choose 2048
1583
- generate_and_tokenize_prompt_fun = partial(generate_and_tokenize_prompt, prompt_type=prompt_type,
1584
- train_on_inputs=train_on_inputs, add_eos_token=add_eos_token,
1585
- cutoff_len=cutoff_len, tokenizer=tokenizer)
1586
- from datasets import load_dataset
1587
- data = load_dataset("json", data_files={"train": filename})
1588
- val_set_size = 0.90
1589
- train_val = data["train"].train_test_split(
1590
- test_size=val_set_size, shuffle=True, seed=42
1591
- )
1592
- train_data = train_val["train"]
1593
- train_data = train_data.shuffle().map(generate_and_tokenize_prompt_fun, num_proc=os.cpu_count())
1594
-
1595
- df_tokens = pd.DataFrame([len(x) for x in train_data['input_ids']], columns=['token_count'])
1596
-
1597
- plt.figure(figsize=(10, 10))
1598
- plt.hist(df_tokens['token_count'], bins=100)
1599
- token_avg = np.mean(df_tokens['token_count'])
1600
- token_median = np.median(df_tokens['token_count'])
1601
- plt.title("token_count with cutoff=%s avg: %s median: %s" % (cutoff_len, token_avg, token_median))
1602
- plt.savefig('token_hist_%s.png' % cutoff_len)
1603
- plt.close()
1604
-
1605
-
1606
- def get_unhelpful_list():
1607
- # base versions
1608
- unhelpful = ["I'm sorry, I didn't quite understand your question, could you please rephrase it?",
1609
- "I'm sorry, but I don't understand your question. Could you please rephrase it?",
1610
- "I'm sorry, I don't quite understand your question",
1611
- "I'm sorry, I don't know",
1612
- "I'm sorry, but I don't know",
1613
- "I don't know anything",
1614
- "I do not know",
1615
- "I don't know",
1616
- "I don't know how",
1617
- "I do not know how",
1618
- "Can you please explain what you mean",
1619
- "please explain what you mean",
1620
- "please explain",
1621
- "I'm sorry, but I don't know how to tell a story. Can you please explain what you mean by",
1622
- "I'm sorry but I don't understand what you mean",
1623
- "I don't understand",
1624
- "I don't have the ability",
1625
- "I do not have the ability",
1626
- "I do not have",
1627
- "I am a language model,",
1628
- "I am a large language model,",
1629
- "I do not understand your question. Can you please try to make it clearer?",
1630
- "I'm sorry, but as an AI language model",
1631
- "I apologize, but I cannot rephrase text that I cannot understand. Your post is difficult to read and follow.",
1632
- "I apologize, but I am not h2oGPT. I am a language model developed by H2O.ai. How may I help you?",
1633
- "Sorry, but I am not an actual Linux shell, nor am I capable of emulating one. I am an open source chat assistant and would be glad t",
1634
- "I apologize, but I cannot perform the task you have requested.",
1635
- "I'm sorry, I cannot perform this task as I am an AI language model and do not have access",
1636
- "I'm sorry, I'm not sure what you're asking for here.",
1637
- "I'm not sure what you are asking",
1638
- "You need to provide more context",
1639
- ]
1640
- # reduced versions, with redundant parts, just to give context for where they came from
1641
- unhelpful += ["sorry, I didn't quite understand your question",
1642
- "I didn't quite understand your question",
1643
- "I didn't understand your question",
1644
- "I did not understand your question",
1645
- "I did not understand the question",
1646
- "could you please rephrase",
1647
- "could you rephrase",
1648
- "I do not understand your question.",
1649
- "I do not understand the question.",
1650
- "I do not understand that question.",
1651
- "Can you please try to make it clearer",
1652
- "Can you try to make it clearer",
1653
- "sorry, but as an AI language model",
1654
- "as an AI language model",
1655
- "I apologize, but I cannot",
1656
- "I cannot rephrase text",
1657
- "I cannot understand. Your post is difficult to read and follow.",
1658
- "Your post is difficult to read and follow.",
1659
- "I apologize, but I am",
1660
- "Sorry, but I am not ",
1661
- "nor am I capable",
1662
- "I am not capable of",
1663
- "I apologize, but I cannot perform the task you have requested",
1664
- "I cannot perform the task",
1665
- "I cannot complete the task",
1666
- "I'm sorry",
1667
- "I am sorry",
1668
- "do not have access",
1669
- "not sure what you're asking for",
1670
- "not sure what you are asking for",
1671
- "not sure what is being asked",
1672
- "I'm not sure what you are asking",
1673
- "not sure what you are asking",
1674
- "You need to provide more context",
1675
- "provide more context",
1676
- ]
1677
- unhelpful += ["As a large language model",
1678
- "cannot provide any information",
1679
- "As an artificial intelligence I do not have the capability",
1680
- "As an artificial intelligence I don't have the capability",
1681
- "As an artificial intelligence I can't",
1682
- "As an artificial intelligence I cannot",
1683
- "I am sorry but I do not understand",
1684
- "Can you please explain",
1685
- "(sorry couldn't resist)",
1686
- "(sorry could not resist)",
1687
- " :)",
1688
- " ;)",
1689
- " :-)",
1690
- " ;-)",
1691
- " lol ",
1692
- "Thanks so much!!!",
1693
- "Thank You :)!!!",
1694
- "Please try not to repeat",
1695
- "I am an AI language model",
1696
- "I'm a AI assistant that",
1697
- "I'm an AI assistant that",
1698
- "I am an AI assistant that",
1699
- "etc.",
1700
- "etc.etc.",
1701
- "etc. etc.",
1702
- "etc etc",
1703
- ]
1704
- return unhelpful
1705
-
1706
-
1707
- def test_check_unhelpful():
1708
- # file = '/home/jon/Downloads/openassistant_oasst1_h2ogpt_graded.json'
1709
- file = '/home/jon/Downloads/openassistant_oasst1_h2ogpt_grades.json'
1710
- # file = 'h2ogpt-oig-oasst1-instruct-cleaned-v2.json'
1711
-
1712
- unhelpful = get_unhelpful_list()
1713
- # data = json.load(open(file, 'rt'))
1714
- df = pd.read_json(file)
1715
-
1716
- use_reward_score_threshold = False
1717
- use_bleu_threshold = False
1718
- use_sentence_sim = True
1719
-
1720
- from sacrebleu.metrics import BLEU
1721
- bleu = BLEU()
1722
- from nltk.translate.bleu_score import sentence_bleu
1723
-
1724
- def get_bleu(actual, expected_list):
1725
- # return bleu.sentence_score(actual, expected_list).score
1726
- return sentence_bleu(expected_list, actual)
1727
-
1728
- threshold = 0.0
1729
- if use_reward_score_threshold:
1730
- df = df[df['grade_deberta'] > threshold]
1731
-
1732
- # back to as if original json load
1733
- data = df.to_dict(orient='records')
1734
- bads = {}
1735
- string_all = str(data)
1736
- for sub in unhelpful:
1737
- bads[sub] = string_all.count(sub)
1738
- bads = {k: v for k, v in bads.items() if v > 0}
1739
- import pprint
1740
- pp = pprint.PrettyPrinter(indent=4)
1741
- pp.pprint(bads)
1742
-
1743
- total_bads = sum(list(bads.values()))
1744
- print('total_bads: %s' % total_bads, flush=True)
1745
-
1746
- # check just bot
1747
- import re
1748
- convs = [[x.strip() for x in re.split(r'%s|%s' % (human, bot), y['input']) if x.strip()] for y in data]
1749
- humans = [[x for i, x in enumerate(y) if i % 2 == 0] for y in convs]
1750
- bots = [[x for i, x in enumerate(y) if i % 2 == 1] for y in convs]
1751
-
1752
- # FIXME: apply back to json etc., just see for now
1753
- bleu_threshold = 0.9
1754
- if use_bleu_threshold:
1755
- bots = [[x for x in y if get_bleu(x, unhelpful) < bleu_threshold] for y in tqdm(bots)]
1756
-
1757
- cosine_sim_threshold = 0.8
1758
- if use_sentence_sim:
1759
- # pip install sentence_transformers-2.2.2
1760
- from sentence_transformers import SentenceTransformer
1761
- # sent_model = 'bert-base-nli-mean-tokens'
1762
- # sent_model = 'nli-distilroberta-base-v2'
1763
- sent_model = 'all-MiniLM-L6-v2'
1764
- model = SentenceTransformer(sent_model)
1765
- sentence_embeddings = model.encode(unhelpful)
1766
- from sklearn.metrics.pairwise import cosine_similarity
1767
- bots = [x for x in tqdm(bots) if
1768
- np.max(cosine_similarity(model.encode(x), sentence_embeddings)) < cosine_sim_threshold]
1769
-
1770
- bads_bots = {}
1771
- string_all = str(bots)
1772
- for sub in unhelpful:
1773
- bads_bots[sub] = string_all.count(sub)
1774
- bads_bots = {k: v for k, v in bads_bots.items() if v > 0}
1775
- import pprint
1776
- pp = pprint.PrettyPrinter(indent=4)
1777
- pp.pprint(bads_bots)
1778
-
1779
- total_bads_bots = sum(list(bads_bots.values()))
1780
- print('threshold: %g use_bleu_threshold: %g total_bads_bots: %s total_bots: %s total_humans: %s' % (
1781
- threshold, use_bleu_threshold, total_bads_bots, len(bots), len(humans)), flush=True)
1782
-
1783
- # assert len(bads) == 0, bads
1784
- assert len(bads_bots) == 0, bads_bots
1785
-
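The sentence-similarity branch above embeds both the known unhelpful phrases and the bot replies with sentence-transformers, then drops replies whose maximum cosine similarity exceeds 0.8. A minimal standalone sketch of that filter on toy replies:

```python
# Cosine-similarity filter against known unhelpful phrases (toy data).
from sentence_transformers import SentenceTransformer
from sklearn.metrics.pairwise import cosine_similarity

unhelpful = ["I'm sorry, I don't know", "I am an AI language model"]
model = SentenceTransformer("all-MiniLM-L6-v2")  # same model name as above
unhelpful_emb = model.encode(unhelpful)

replies = ["The capital of France is Paris.", "I'm sorry, I don't know."]
sims = cosine_similarity(model.encode(replies), unhelpful_emb).max(axis=1)
kept = [r for r, s in zip(replies, sims) if s < 0.8]  # 0.8 = cosine_sim_threshold
print(kept)
```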
1786
-
1787
- def test_fortune2000_personalized():
1788
- row_list = []
1789
- import glob
1790
- if not os.path.isdir("wikitext"):
1791
- raise RuntimeError("download https://github.com/h2oai/h2ogpt/files/11423008/wikitext.zip and unzip")
1792
- for file in glob.glob("wikitext/*.txt"):
1793
- with open(file, "r") as f:
1794
- blob = f.read()
1795
- N = 512 * 4
1796
- row_list.extend([{'input': s, 'prompt_type': 'plain', 'source': "%s" % os.path.basename(file)}
1797
- for s in get_sentences(blob, N) if s])
1798
- personality = create_personality_data()
1799
- import copy
1800
- for i in range(10):
1801
- row_list.extend(copy.deepcopy(personality))
1802
- np.random.seed(123)
1803
- np.random.shuffle(row_list)
1804
- for i in range(len(row_list)):
1805
- row_list[i]['id'] = i
1806
- for i in range(len(row_list)):
1807
- assert row_list[i]['id'] == i
1808
- with open("h2ogpt-fortune2000-personalized.json", "w") as ff:
1809
- ff.write(json.dumps(row_list, indent=2))
create_data.py ADDED
@@ -0,0 +1 @@
 
 
1
+ ../../create_data.py
enums.py ADDED
@@ -0,0 +1 @@
 
 
1
+ ../../enums.py
finetune.py DELETED
@@ -1,670 +0,0 @@
1
- import os
2
- import sys
3
- from functools import partial
4
- from typing import List, Union
5
- import fire
6
- import numpy as np
7
-
8
- from loaders import get_loaders, get_tokenizer
9
- from prompter import generate_prompt, prompt_types
10
- from utils import get_githash, copy_code
11
- import torch
12
-
13
-
14
- def log(*args, **kwargs):
15
- if int(os.environ.get("LOCAL_RANK", 0)) == 0:
16
- if 'flush' not in kwargs:
17
- kwargs['flush'] = True
18
- print(*args, **kwargs)
19
-
20
-
21
- # supported by huggingface evaluate
22
- supported_metrics = ['bleu', 'rouge', 'sacrebleu', 'meteor']
23
-
24
-
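Each entry in supported_metrics is loaded on demand through the evaluate library inside train(); a quick example of what such a metric object computes:

```python
# Loading and computing one of the supported metrics via huggingface evaluate.
import evaluate

bleu = evaluate.load("bleu")
result = bleu.compute(predictions=["the cat sat on the mat"],
                      references=[["the cat sat on the mat"]])
print(result["bleu"])  # 1.0 for an exact match
```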
25
- def train(
26
- save_code: bool = False,
27
- run_id: int = None,
28
-
29
- base_model: str = 'h2oai/h2ogpt-oig-oasst1-512-6_9b',
30
- # base_model: str = 'h2oai/h2ogpt-oasst1-512-12b',
31
- # base_model: str = 'h2oai/h2ogpt-oasst1-512-20b',
32
- # base_model: str = 'EleutherAI/gpt-neox-20b',
33
- # base_model: str = 'EleutherAI/pythia-12b-deduped',
34
- # base_model: str = 'togethercomputer/GPT-NeoXT-Chat-Base-20B',
35
- # base_model: str = 'decapoda-research/llama-7b-hf',
36
- # base_model: str = 'decapoda-research/llama-13b-hf',
37
- # base_model: str = 'decapoda-research/llama-30b-hf',
38
- # base_model: str = 'EleutherAI/gpt-j-6B',
39
-
40
- # only needed if base_model is self-exported HF state without tokenizer
41
- tokenizer_base_model: str = None,
42
- # tokenizer_base_model: str = 'EleutherAI/gpt-neox-20b',
43
-
44
- data_path: str = "h2oai/openassistant_oasst1_h2ogpt",
45
- data_col_dict: dict = None,
46
- # data_path: str = "./dai_docs.train.json",
47
- prompt_type: Union[str, int] = "plain", # "plain", "instruct", "quality", "human_bot", "dai_faq"
48
-
49
- valid_path: str = None,
50
- # valid_path: str = "./dai_docs.valid.json",
51
-
52
- # data_mix_in_path: str = "laion/OIG", # way too big, medium quality
53
- data_mix_in_path: str = "0-hero/OIG-small-chip2", # high quality, 50 MB, good enough for now
54
- data_mix_in_factor: float = 0.0, # >1: more mix-in data, <1: more of data_path data
55
- data_mix_in_col_dict: dict = {'user': 'instruction', 'chip2': 'output'},
56
- data_mix_in_prompt_type: str = "instruct", # just instruction->output, same as instruct
57
-
58
- output_dir: str = None,
59
-
60
- # LoRA checkpoint continuation
61
- lora_weights: str = "",
62
-
63
- # batching training hyperparams
64
- batch_size: int = 128,
65
- micro_batch_size: int = 4,
66
- gradient_checkpointing=False, # unnecessary with gradient accumulation enabled
67
- fp16=True,
68
- train_8bit=False,
69
- train_4bit=False,
70
-
71
- # general training hyperparams
72
- num_epochs: float = 1,
73
- learning_rate: float = 3e-4,
74
-
75
- # validation settings
76
- val_set_size: int = None,
77
- val_metrics: List[str] = [],
78
- eval_steps: int = None, # to control eval steps via steps
79
- eval_epochs: float = None, # to control eval steps via epochs
80
-
81
- # lora hyperparams
82
- lora_r: int = 8,
83
- lora_alpha: int = 16,
84
- lora_dropout: float = 0.05,
85
- lora_target_modules: List[str] = None,
86
- llama_type: bool = None,
87
- llama_flash_attn: bool = False,
88
-
89
- # llm hyperparams
90
- train_on_inputs: bool = True, # if False, masks out inputs in loss
91
- group_by_length: bool = False, # if True, faster, but produces an odd training loss curve
92
- resume_from_checkpoint: str = None, # either training checkpoint or final adapter
93
- cutoff_len: int = 512, # larger values use more memory
94
- drop_truncations: bool = False, # if True, drop any truncated long sequences
95
-
96
- # torch training params
97
- ddp: bool = True, # set to False if OOM with True, for multi-GPU model parallelism
98
- local_files_only: bool = False, # else will download new versions, normally unwanted
99
- resume_download: bool = True,
100
- use_auth_token: Union[str, bool] = False, # True requires CLI did huggingface-cli login before running
101
- warmup_steps: int = 100,
102
- logging_steps: int = 1,
103
- save_steps: int = None, # must be round multiple of eval_steps
104
- save_total_limit: int = 3,
105
- add_eos_token: bool = False,
106
- ):
107
-
108
- if llama_flash_attn:
109
- # Need to call this before importing transformers.
110
- from llama_flash_attn_monkey_patch import replace_llama_attn_with_flash_attn
111
- replace_llama_attn_with_flash_attn()
112
-
113
- # allow set token directly
114
- use_auth_token = os.environ.get("HUGGINGFACE_API_TOKEN", use_auth_token)
115
-
116
- prompt_type = str(prompt_type) # migration from integers
117
- assert prompt_type in prompt_types
118
-
119
- world_size = int(os.getenv("WORLD_SIZE", 1))
120
- local_rank = int(os.getenv("LOCAL_RANK", 0))
121
- rank = int(os.getenv("RANK", 0))
122
- print(f"local_rank: {local_rank}")
123
- print(f"global rank: {rank}")
124
-
125
- gpus = max(world_size, torch.cuda.device_count())
126
- run_id = run_id or 0
127
- if not data_path:
128
- raise ValueError("No data_path provided")
129
- if not output_dir:
130
- output_dir = f"{base_model.split('/')[-1]}.{data_path.replace('/', '')}.{num_epochs}_epochs.{get_githash() or 'nogit'}.{run_id}"
131
- if os.path.exists(output_dir) and not resume_from_checkpoint:
132
- raise FileExistsError(f"output_dir {output_dir} based on run_id {run_id} already exists. Please pick a different run_id.")
133
- else:
134
- if os.path.exists(output_dir) and not resume_from_checkpoint:
135
- raise FileExistsError(f"output_dir {output_dir} already exists. Please pick a different output_dir, or specify a run_id instead.")
136
- device_map = "auto"
137
-
138
- if save_code:
139
- copy_code(run_id)
140
- if tokenizer_base_model is None:
141
- tokenizer_base_model = base_model
142
- if llama_type is None:
143
- llama_type = "llama" in base_model.lower()
144
- if llama_type and llama_flash_attn:
145
- import pkg_resources
146
- try:
147
- pkg_resources.get_distribution('flash_attn')
148
- can_do_flash_attn = True
149
- except (pkg_resources.DistributionNotFound, pkg_resources.ContextualVersionConflict):
150
- can_do_flash_attn = False
151
-
152
- if not can_do_flash_attn:
153
- raise RuntimeError("""Flash attention not installed.
154
- NOTE: for current pytorch 2.0, flash attention requires installing cuda 11.7 via https://developer.nvidia.com/cuda-11-7-0-download-archive?target_os=Linux&target_arch=x86_64&Distribution=Ubuntu&target_version=20.04&target_type=runfile_local and then when running, to avoid installing driver, docs, samples, just install toolkit. Then when pip installing flash attention do:
155
-
156
- CUDA_HOME=/usr/local/cuda-11.7 pip install flash-attn""")
157
- assert (
158
- base_model
159
- ), "Please specify a --base_model, e.g. --base_model='decapoda-research/llama-7b-hf'"
160
- gradient_accumulation_steps = batch_size // micro_batch_size
161
- assert gradient_accumulation_steps >= world_size, "must increase batch_size for multi-GPU"
162
-
163
- device_map = "auto"
164
-
165
- locals_dict = locals()
166
- locals_print = '\n'.join(['%s: %s' % (k, v) for k, v in locals_dict.items()])
167
- log(f"Training model with params:\n{locals_print}")
168
- log("Command: %s\nHash: %s" % (str(' '.join(sys.argv)), get_githash()))
169
-
170
- max_memory = None
171
- if gpus > 1:
172
- if ddp:
173
- log("Distributed: data parallel")
174
- device_map = {"": int(os.environ.get("LOCAL_RANK") or 0)}
175
- gradient_accumulation_steps = gradient_accumulation_steps // world_size
176
- else:
177
- free_in_GB = int(min(torch.cuda.mem_get_info()) / 1024 ** 3)
178
- max_memory = f"{free_in_GB - 2}GB"
179
- max_memory = {i: max_memory for i in range(gpus)}
180
- log("world_size: %d" % world_size)
181
- log("num_gpus: %d" % gpus)
182
- log("max mem: %s" % max_memory)
183
-
184
- model_loader, tokenizer_loader = get_loaders(llama_type=llama_type, model_name=base_model, reward_type=False)
185
-
186
- model = model_loader.from_pretrained(
187
- base_model,
188
- load_in_8bit=train_8bit,
189
- load_in_4bit=train_4bit,
190
- device_map=device_map,
191
- torch_dtype=torch.float16,
192
- max_memory=max_memory,
193
- local_files_only=local_files_only,
194
- trust_remote_code=True,
195
- resume_download=resume_download,
196
- use_auth_token=use_auth_token,
197
- )
198
- if gpus > 1:
199
- if not ddp:
200
- log("model parallel")
201
- model.is_parallelizable = True
202
- model.model_parallel = True
203
-
204
- tokenizer = get_tokenizer(tokenizer_loader, tokenizer_base_model, local_files_only, resume_download, use_auth_token)
205
-
206
- if train_8bit or train_4bit:
207
- from peft import (
208
- prepare_model_for_kbit_training,
209
- )
210
-
211
- model = prepare_model_for_kbit_training(model)
212
-
213
- from peft import LoraConfig, get_peft_model, set_peft_model_state_dict
214
- try:
215
- from peft import utils
216
- lora_mappings = utils.TRANSFORMERS_MODELS_TO_LORA_TARGET_MODULES_MAPPING.copy()
217
- except AttributeError:
218
- from peft import mapping
219
- lora_mappings = mapping.TRANSFORMERS_MODELS_TO_LORA_TARGET_MODULES_MAPPING.copy()
220
- lora_mappings['distilgpt2'] = ["c_attn"]
221
-
222
- if lora_weights:
223
-
224
- from peft import PeftModel
225
- model = PeftModel.from_pretrained(
226
- model,
227
- lora_weights,
228
- torch_dtype=torch.float16,
229
- device_map=device_map,
230
- local_files_only=local_files_only,
231
- resume_download=resume_download,
232
- use_auth_token=use_auth_token,
233
- )
234
- elif lora_r > 0:
235
- if lora_target_modules is None:
236
- base_model_lower = base_model.lower()
237
- if base_model_lower in lora_mappings:
238
- lora_target_modules_cand = [lora_mappings[base_model_lower]]
239
- else:
240
- lora_target_modules_cand = [["query_key_value"], ["q_proj", "v_proj"]]
241
- else:
242
- lora_target_modules_cand = [lora_target_modules]
243
-
244
- for lora_target_modules in lora_target_modules_cand:
245
- try:
246
- config = LoraConfig(
247
- r=lora_r,
248
- lora_alpha=lora_alpha,
249
- target_modules=lora_target_modules,
250
- lora_dropout=lora_dropout,
251
- bias="none",
252
- task_type="CAUSAL_LM",
253
- )
254
- model = get_peft_model(model, config)
255
- break
256
- except ValueError as e:
257
- if "Target modules" in str(e) and "not found" in str(e):
258
- continue
259
- else:
260
- raise
261
- from peft import PeftModel
262
- assert isinstance(model, PeftModel), "LoRA failed. Please provide --lora_target_modules explicitly."
263
- if resume_from_checkpoint:
264
- # Check the available weights and load them
265
- checkpoint_name = os.path.join(
266
- resume_from_checkpoint, "pytorch_model.bin"
267
- ) # Full checkpoint
268
- if not os.path.exists(checkpoint_name):
269
- checkpoint_name = os.path.join(
270
- resume_from_checkpoint, "adapter_model.bin"
271
- ) # only LoRA model - LoRA config above has to fit
272
- resume_from_checkpoint = False # So the trainer won't try loading its state
273
- # The two files above have a different name depending on how they were saved, but are actually the same.
274
- if os.path.exists(checkpoint_name):
275
- log(f"Restarting from {checkpoint_name}")
276
- adapters_weights = torch.load(checkpoint_name)
277
- set_peft_model_state_dict(model, adapters_weights)
278
- else:
279
- log(f"Checkpoint {checkpoint_name} not found")
280
-
281
- print(model)
282
- try:
283
- # only for PeftModel
284
- model.print_trainable_parameters() # Be more transparent about the % of trainable params.
285
- except:
286
- pass
287
-
288
- metrics = {}
289
- for name in supported_metrics:
290
- if name in val_metrics:
291
- import evaluate # Causes hang for 'python generate.py' on dual 4090 if imported early, 100% reproducible
292
- metrics[name] = evaluate.load(name)
293
- log("Using Validation Metrics: %s" % str(list(metrics.keys())))
294
- log("Supported Metrics: %s" % supported_metrics)
295
-
296
- if val_set_size is None:
297
- if len(metrics) == 0:
298
- val_set_size = 1000
299
- else:
300
- val_set_size = 100
301
- log("Auto set val_set_size %s" % val_set_size)
302
- elif val_set_size < 1.0 and val_set_size != 0:
303
- raise RuntimeError("Fractional validation size not supported.")
304
-
305
- from datasets import load_dataset, concatenate_datasets
306
- if valid_path:
307
- data = load_dataset("json", data_files={"train": data_path, "valid": valid_path})
308
- else:
309
- if "json" in data_path:
310
- data = load_dataset("json", data_files={"train": data_path})
311
- else:
312
- data = load_dataset(data_path)
313
- data = data.rename_columns(data_col_dict or {})
314
-
315
- valid_data = None
316
- train_data_mix_in = None
317
- valid_data_mix_in = None
318
-
319
- if data_mix_in_path and data_mix_in_factor > 0:
320
- # get mix-in training/validation data - to keep model "sane"
321
- num_rows = data["train"].num_rows
322
- log("Loading mix-in dataset: %s" % data_mix_in_path)
323
- if "json" in data_mix_in_path:
324
- data_mix_in = load_dataset("json", data_files={"train": data_mix_in_path})["train"]
325
- else:
326
- data_mix_in = load_dataset(data_mix_in_path)["train"] # can be large
327
- data_mix_in = data_mix_in.rename_columns(data_mix_in_col_dict or {})
328
- mix_in_rows = int(num_rows * data_mix_in_factor)
329
-
330
- if mix_in_rows > data_mix_in.num_rows:
331
- # duplicate rows if mix-in is smaller than required
332
- log("Duplicating mixin to compensate for its size for training size and mixin fraction")
333
- data_mix_in = concatenate_datasets([data_mix_in] * int(np.ceil(mix_in_rows / data_mix_in.num_rows)))
334
-
335
- # only get as much as we need to balance
336
- valid_size = min(data_mix_in.num_rows // 2, val_set_size or 0)
337
- train_size = max(1, min(data_mix_in.num_rows - valid_size, mix_in_rows))
338
- mixin_small = data_mix_in.train_test_split(
339
- test_size=train_size + valid_size,
340
- shuffle=True, seed=np.random.randint(10000),
341
- )["test"]
342
- if valid_size:
343
- mixin_train_test = mixin_small.train_test_split(
344
- test_size=valid_size, shuffle=False,
345
- )
346
- train_data_mix_in = mixin_train_test["train"]
347
- valid_data_mix_in = mixin_train_test["test"]
348
- else:
349
- train_data_mix_in = mixin_small
350
-
351
- if "prompt_type" not in train_data_mix_in.column_names:
352
- train_data_mix_in = train_data_mix_in.add_column(
353
- "prompt_type",
354
- [data_mix_in_prompt_type] * train_data_mix_in.num_rows,
355
- )
356
- log("Added prompt type %s to mix-in training data" % data_mix_in_prompt_type)
357
- if valid_data_mix_in and "prompt_type" not in valid_data_mix_in.column_names:
358
- valid_data_mix_in = valid_data_mix_in.add_column(
359
- "prompt_type",
360
- [data_mix_in_prompt_type] * valid_data_mix_in.num_rows,
361
- )
362
- log("Added prompt type %s to mix-in validation data" % data_mix_in_prompt_type)
363
- log("Created mix-in data:\nTrain %s\nValid %s" % (train_data_mix_in, valid_data_mix_in))
364
-
365
- # get our own training/validation data - for fine-tuning
366
- if val_set_size > 0 and not valid_path and not data_mix_in_path:
367
- # create valid split from train
368
- train_val = data["train"].train_test_split(
369
- test_size=val_set_size, shuffle=True, seed=42
370
- )
371
- train_data = train_val["train"]
372
- valid_data = train_val["test"]
373
- else:
374
- train_data = data["train"]
375
- if valid_path:
376
- # use given valid split, has priority over data_mix_in_path
377
- valid_data = data["valid"]
378
- if "prompt_type" not in train_data.column_names:
379
- train_data = train_data.add_column(
380
- "prompt_type",
381
- [prompt_type] * train_data.num_rows,
382
- )
383
- log("Added prompt type %s to training data" % prompt_type)
384
- if valid_data and "prompt_type" not in valid_data.column_names:
385
- valid_data = valid_data.add_column(
386
- "prompt_type",
387
- [prompt_type] * valid_data.num_rows,
388
- )
389
- log("Added prompt type %s to validation data" % prompt_type)
390
-
391
- assert train_data is not None
392
-
393
- generate_and_tokenize_prompt_fun = partial(generate_and_tokenize_prompt, prompt_type=prompt_type,
394
- train_on_inputs=train_on_inputs, add_eos_token=add_eos_token,
395
- cutoff_len=cutoff_len, tokenizer=tokenizer)
396
-
397
- # shuffle and tokenize data
398
- if train_data_mix_in:
399
- train_data = concatenate_datasets([train_data, train_data_mix_in])
400
- log("Tokenizing %s training rows" % train_data.num_rows)
401
- train_data = train_data.shuffle().map(generate_and_tokenize_prompt_fun, num_proc=os.cpu_count() // torch.cuda.device_count())
402
- if drop_truncations:
403
- log("avoid keeping truncated cases to avoid contaminating model with truncation cases. Original size: %s" % train_data.num_rows)
404
- prune_long_sequences_func = partial(prune_long_sequences, cutoff_len=cutoff_len)
405
- train_data = train_data.filter(prune_long_sequences_func, num_proc=os.cpu_count() // torch.cuda.device_count())
406
- log("avoid keeping truncated cases to avoid contaminating model with truncation cases. New size: %s" % train_data.num_rows)
407
- train_set_size = len(train_data)
408
-
409
- if valid_data and valid_data_mix_in:
410
- valid_data = concatenate_datasets([valid_data, valid_data_mix_in])
411
- elif valid_data_mix_in:
412
- valid_data = valid_data_mix_in
413
-
414
- if valid_data:
415
- log("Tokenizing %s validation rows" % valid_data.num_rows)
416
- valid_data = valid_data.shuffle().map(generate_and_tokenize_prompt_fun, num_proc=os.cpu_count() // torch.cuda.device_count())
417
- val_set_size = len(valid_data)
418
- else:
419
- val_set_size = 0
420
- log("Final fine-tuning data:\nTrain %s\nValid %s" % (train_data, valid_data))
421
- sample_row_dict = train_data[:1]
422
- del sample_row_dict['input_ids']
423
- del sample_row_dict['attention_mask']
424
- del sample_row_dict['labels']
425
- log("Sample input: %s" % sample_row_dict)
426
-
427
- try:
428
- import neptune
429
- from transformers.integrations import NeptuneCallback
430
-
431
- neptune_run = neptune.init_run(
432
- source_files=[],
433
- )
434
- log("Connected to Neptune.")
435
- except ImportError:
436
- neptune_run = None
437
- log("Please pip install neptune for tracking.")
438
- except neptune.exceptions.NeptuneMissingApiTokenException:
439
- neptune_run = None
440
- os.environ["NEPTUNE_MODE"] = 'debug'
441
- log("No neptune configured, set NEPTUNE_API_TOKEN env var.")
442
-
443
- if neptune_run:
444
- neptune_callback = NeptuneCallback(run=neptune_run)
445
- callbacks = [neptune_callback]
446
- else:
447
- from transformers.integrations import TensorBoardCallback, is_tensorboard_available
448
- if is_tensorboard_available():
449
- # tensorboard --logdir=runs/
450
- from torch.utils.tensorboard import SummaryWriter
451
- tb_writer = SummaryWriter()
452
- callbacks = [TensorBoardCallback(tb_writer=tb_writer)]
453
- else:
454
- callbacks = []
455
-
456
- expected_steps = (train_set_size * num_epochs) // batch_size
457
- if eval_steps is None and eval_epochs is None:
458
- # 20 evaluations for a run
459
- eval_steps = max(1, int(expected_steps / 20))
460
- log("Auto set eval_steps to %s out of %s total training steps" % (eval_steps, expected_steps))
461
- elif eval_steps is None and eval_epochs is not None:
462
- eval_steps = max(1, int(expected_steps * eval_epochs / num_epochs))
463
- log("Auto converted eval_epochs=%s to eval_steps %s"
464
- " out of %s total training steps" % (eval_epochs, eval_steps, expected_steps))
465
- if save_steps is None:
466
- save_steps = eval_steps
467
- log("Auto step save_steps to %s" % save_steps)
468
- elif save_steps > eval_steps:
469
- # save steps must be round multiple of eval_steps
470
- save_steps0 = save_steps
471
- save_steps = max(1, (save_steps//eval_steps)) * eval_steps
472
- if save_steps0 != save_steps:
473
- log("Auto converted save_steps from %s to %s" % (save_steps0, save_steps))
474
-
475
- def compute_metrics(eval_preds):
476
- # e.g. see: https://huggingface.co/docs/transformers/v4.25.1/en/tasks/translation#evaluate
477
- inputs = eval_preds.inputs
478
- label_ids = eval_preds.label_ids
479
- predictions = eval_preds.predictions
480
-
481
- #inputs = np.where(inputs != -100, inputs, tokenizer.pad_token_id)
482
- #decoded_inputs = tokenizer.batch_decode(inputs, skip_special_tokens=True)
483
- #decoded_inputs = [pred.strip() for pred in decoded_inputs]
484
-
485
- label_ids = np.where(label_ids != -100, label_ids, tokenizer.pad_token_id)
486
- # tokenizer behavior like generate time
487
- decoded_labels = tokenizer.batch_decode(label_ids, skip_special_tokens=True,
488
- clean_up_tokenization_spaces=True)
489
- decoded_labels = [pred.strip() for pred in decoded_labels]
490
-
491
- predictions = np.argmax(predictions, -1)
492
- predictions = np.where(predictions != -100, predictions, tokenizer.pad_token_id)
493
- # tokenizer behavior like generate time
494
- decoded_predictions = tokenizer.batch_decode(predictions, skip_special_tokens=True,
495
- clean_up_tokenization_spaces=True)
496
- decoded_predictions = [pred.strip() for pred in decoded_predictions]
497
-
498
- result = {}
499
- for metric in metrics.values():
500
- result1 = metric.compute(predictions=decoded_predictions, references=decoded_labels)
501
- # get rid of lists, for precision etc., for now
502
- numeric_results = {k: v for k, v in result1.items() if isinstance(v, (int, float))}
503
- result.update(numeric_results)
504
- return result
505
-
506
- # the callback that computes metrics of interest
507
- if val_metrics:
508
- trainer_kwargs = dict(compute_metrics=compute_metrics)
509
- else:
510
- trainer_kwargs = dict()
511
-
512
- import transformers
513
- trainer = transformers.Trainer(
514
- model=model,
515
- tokenizer=tokenizer,
516
- train_dataset=train_data,
517
- eval_dataset=valid_data,
518
- # FIXME: might need Seq2SeqTrainingArguments for some models
519
- args=transformers.TrainingArguments(
520
- per_device_train_batch_size=micro_batch_size,
521
- per_device_eval_batch_size=1,
522
- eval_accumulation_steps=10,
523
- # predict_with_generate=True, # SEQ2SEQ only
524
- include_inputs_for_metrics=True,
525
- gradient_accumulation_steps=gradient_accumulation_steps,
526
- warmup_steps=warmup_steps,
527
- num_train_epochs=num_epochs,
528
- learning_rate=learning_rate,
529
- gradient_checkpointing=gradient_checkpointing,
530
- fp16=fp16,
531
- # consider 8-bit adam: https://huggingface.co/docs/transformers/v4.18.0/en/performance#8bit-adam
532
- optim="adamw_torch", # consider "adafactor" to save memory
533
- logging_steps=logging_steps,
534
- logging_strategy="steps",
535
- evaluation_strategy="steps" if val_set_size > 0 else "no",
536
- save_strategy="steps",
537
- eval_steps=eval_steps if val_set_size > 0 else None,
538
- save_steps=save_steps,
539
- output_dir=output_dir,
540
- save_total_limit=save_total_limit,
541
- load_best_model_at_end=True if val_set_size > 0 else False,
542
- ddp_find_unused_parameters=False if ddp else None,
543
- group_by_length=group_by_length,
544
- #fsdp="shard_grad_op auto_wrap" if gpus > 1 and not ddp else None,
545
- #fsdp_min_num_params=20000 if gpus > 1 and not ddp else None,
546
- report_to='tensorboard' if not neptune_run else 'neptune',
547
- ),
548
- data_collator=transformers.DataCollatorForSeq2Seq(
549
- tokenizer, pad_to_multiple_of=8, return_tensors="pt", padding=True
550
- ),
551
- callbacks=callbacks,
552
- **trainer_kwargs,
553
- )
554
- model.config.use_cache = False
555
-
556
- old_state_dict = model.state_dict
557
- from peft import get_peft_model_state_dict
558
-
559
- model.state_dict = (
560
- lambda self, *_, **__: get_peft_model_state_dict(self, old_state_dict())
561
- ).__get__(model, type(model))
562
-
563
- if torch.__version__ >= "2" and sys.platform != "win32":
564
- model = torch.compile(model)
565
- # WIP (not generally replacing layers until pytorch 2.1)
566
- if not llama_flash_attn:
567
- torch.backends.cuda.enable_flash_sdp(True)
568
-
569
- if gpus > 1 and not ddp:
570
- assert trainer.is_model_parallel
571
- else:
572
- assert not trainer.is_model_parallel
573
- trainer.train(resume_from_checkpoint=resume_from_checkpoint)
574
-
575
- model.save_pretrained(output_dir)
576
-
577
- log("\n If there's a warning about missing keys above, please disregard :)")
578
-
579
-
580
- def tokenize(prompt, tokenizer, cutoff_len, add_eos_token=False):
581
- # there's probably a way to do this with the tokenizer settings
582
- # but again, gotta move fast
583
- result = tokenizer(
584
- prompt,
585
- truncation=True,
586
- max_length=cutoff_len,
587
- padding=False,
588
- return_tensors=None,
589
- )
590
- if (
591
- result["input_ids"][-1] != tokenizer.eos_token_id
592
- and len(result["input_ids"]) < cutoff_len
593
- and add_eos_token
594
- ):
595
- result["input_ids"].append(tokenizer.eos_token_id)
596
- result["attention_mask"].append(1)
597
-
598
- result["labels"] = result["input_ids"].copy()
599
-
600
- return result
601
-
602
-
603
- def prune_long_sequences(data_point, cutoff_len=None):
604
- """
605
- Prune if too long for tokenizer, so truncation doesn't lead training to learn from truncated language
606
- :param data_point:
607
- :param cutoff_len:
608
- :return:
609
- """
610
- assert cutoff_len is not None
611
- return len(data_point['input_ids']) < cutoff_len
612
-
613
-
614
- def generate_and_tokenize_prompt(data_point, prompt_type=None, train_on_inputs=False, add_eos_token=False,
615
- cutoff_len=None, tokenizer=None):
616
- assert prompt_type is not None
617
- assert cutoff_len is not None
618
- assert tokenizer is not None
619
- full_prompt, _, _, _ = generate_prompt(data_point, prompt_type, False, False)
620
- tokenized_full_prompt = tokenize(full_prompt, tokenizer, cutoff_len, add_eos_token=add_eos_token)
621
- if not train_on_inputs:
622
- user_prompt, _, _, _ = generate_prompt({**data_point, "output": ""}, prompt_type, False, False)
623
- tokenized_user_prompt = tokenize(user_prompt, tokenizer, cutoff_len, add_eos_token=add_eos_token)
624
- user_prompt_len = len(tokenized_user_prompt["input_ids"])
625
- if add_eos_token:
626
- user_prompt_len -= 1
627
-
628
- # ignore_index=-100 ensures torch/tf don't include padding token id in CrossEntropyLoss
629
- tokenized_full_prompt["labels"] = [
630
- -100
631
- ] * user_prompt_len + tokenized_full_prompt["labels"][
632
- user_prompt_len:
633
- ] # could be sped up, probably
634
- return tokenized_full_prompt
635
-
636
-
637
- def test_debug():
638
- fire.Fire(train)
639
-
640
-
641
- if __name__ == "__main__":
642
- CONFIG = "NCCL_P2P_LEVEL=LOC WORLD_SIZE=5 torchrun --nnodes=5 --master_addr=10.10.10.2 --master_port=1111 --nproc_per_node=1"
643
- CMD = "finetune.py --data_path=config.json --num_epochs=1 --base_model=decapoda-research/llama-13b-hf"
644
- log(f"""
645
- Example runs on 4 GPUs:
646
- WORLD_SIZE=4 CUDA_VISIBLE_DEVICES="0,1,2,3" torchrun --nproc_per_node=4 finetune.py --base_model='decapoda-research/llama-7b-hf' --data_path=data/config.json --run_id=0 &> 0.log
647
- WORLD_SIZE=4 CUDA_VISIBLE_DEVICES="0,1,2,3" torchrun --nproc_per_node=4 finetune.py --base_model='decapoda-research/llama-30b-hf' --data_path=data/config.json --batch_size=16 --micro_batch_size=1 --run_id=1 --save_code=True &> 1.log
648
- WORLD_SIZE=4 CUDA_VISIBLE_DEVICES="0,1,2,3" torchrun --nproc_per_node=4 finetune.py --base_model='EleutherAI/gpt-j-6B' --data_path=data/config.json --run_id=2 &> 2.log
649
- WORLD_SIZE=4 CUDA_VISIBLE_DEVICES="0,1,2,3" torchrun --nproc_per_node=4 finetune.py --base_model='EleutherAI/gpt-neox-20b' --data_path=data/config.json --run_id=8 --batch_size=16 --micro_batch_size=4 &> 8.log
650
- WORLD_SIZE=4 CUDA_VISIBLE_DEVICES="0,1,2,3" torchrun --nproc_per_node=4 finetune.py --base_model='togethercomputer/GPT-NeoXT-Chat-Base-20B' --data_path=data/config.json --prompt_type='dai_faq' --run_id=13 --batch_size=16 --micro_batch_size=4 --num_epochs=100 --val_set_size=0 data_mix_in_path='' &> 13.log
651
- WORLD_SIZE=4 CUDA_VISIBLE_DEVICES="0,1,2,3" torchrun --nproc_per_node=4 finetune.py --base_model='togethercomputer/GPT-NeoXT-Chat-Base-20B' --data_path=data/config.json --run_id=28 --batch_size=16 --micro_batch_size=4 --num_epochs=8 --val_set_size=0 --data_mix_in_factor=0.1 --data_mix_in_prompt_type='human_bot' --save_code=True --cutoff_len=512 &> 28.log
652
-
653
- All metrics:
654
- CUDA_VISIBLE_DEVICES= finetune.py --data_mix_in_factor=0 --eval_steps=100 --warmup_steps=2 --val_set_size=100 --val_metrics="['bleu', 'rouge', 'sacrebleu', 'meteor']"
655
-
656
- # Fine-tune 20B on 24GB GPUs across 3 nodes with 3+2+2 GPUs
657
- rippa>
658
- NCCL_P2P_LEVEL=LOC WORLD_SIZE=7 CUDA_VISIBLE_DEVICES="0,1,2" torchrun --node_rank 0 --nproc_per_node=3 --master_port=1234 --nnodes=3 --master_addr=10.10.10.2 finetune.py --data_path=merged_shuffled_OIG_87f6a1e788.json --micro_batch_size=1 --batch_size=7 --cutoff_len=512 --run_id=17 &>log.17.rank0
659
- ova>
660
- NCCL_P2P_LEVEL=LOC WORLD_SIZE=7 CUDA_VISIBLE_DEVICES="0,1" torchrun --node_rank 1 --nproc_per_node=2 --master_port=1234 --nnodes=3 --master_addr=10.10.10.2 finetune.py --data_path=merged_shuffled_OIG_87f6a1e788.json --micro_batch_size=1 --batch_size=7 --cutoff_len=512 --run_id=17 &>log.17.rank1
661
- timemachine>
662
- NCCL_P2P_LEVEL=LOC WORLD_SIZE=7 CUDA_VISIBLE_DEVICES="0,1" torchrun --node_rank 2 --nproc_per_node=2 --master_port=1234 --nnodes=3 --master_addr=10.10.10.2 finetune.py --data_path=merged_shuffled_OIG_87f6a1e788.json --micro_batch_size=1 --batch_size=7 --cutoff_len=512 --run_id=17 &>log.17.rank2
663
-
664
- """, flush=True)
665
-
666
- if os.environ.get("LOCAL_RANK") is None:
667
- # then not using torchrun, so can't do distributed, ensure CVD set
668
- assert os.environ.get("CUDA_VISIBLE_DEVICES") is not None, "Run python script using: torchrun finetune.py OR set CUDA_VISIBLE_DEVICES to single GPU"
669
-
670
- fire.Fire(train)
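
Note: the key training detail in the deleted finetune.py above is the label masking in generate_and_tokenize_prompt. When train_on_inputs is False, every token position belonging to the prompt gets label -100, which CrossEntropyLoss ignores, so the model is only trained on the response tokens. The snippet below is a minimal, self-contained sketch of that idea; the toy whitespace "tokenizer" and the names toy_tokenize, mask_prompt_labels, and IGNORE_INDEX are illustrative only, not part of this commit.

IGNORE_INDEX = -100  # label value ignored by torch.nn.CrossEntropyLoss


def toy_tokenize(text):
    # stand-in for a real tokenizer: one integer id per whitespace token
    return [abs(hash(tok)) % 50000 for tok in text.split()]


def mask_prompt_labels(prompt, response):
    # tokenize prompt+response together, then blank out the prompt positions
    full_ids = toy_tokenize(prompt + " " + response)
    prompt_len = len(toy_tokenize(prompt))
    labels = [IGNORE_INDEX] * prompt_len + full_ids[prompt_len:]
    return {"input_ids": full_ids, "labels": labels}


example = mask_prompt_labels("### Instruction: say hi ### Response:", "hi there")
# prompt positions are -100, response positions keep their token ids
print(example["labels"])
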
finetune.py ADDED
@@ -0,0 +1 @@
 
 
1
+ ../../finetune.py
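
Note: for the step bookkeeping in train() above, expected_steps is (train_set_size * num_epochs) // batch_size; if neither eval_steps nor eval_epochs is given the run evaluates roughly 20 times, eval_epochs is otherwise converted to steps, and save_steps is rounded down to a multiple of eval_steps. A small sketch of that arithmetic follows; derive_steps and its defaults are illustrative names, not part of this commit.

def derive_steps(train_set_size, num_epochs, batch_size,
                 eval_steps=None, eval_epochs=None, save_steps=None):
    # mirrors the auto-derivation logic from the deleted finetune.py (sketch only)
    expected_steps = (train_set_size * num_epochs) // batch_size
    if eval_steps is None and eval_epochs is None:
        eval_steps = max(1, expected_steps // 20)  # ~20 evaluations per run
    elif eval_steps is None:
        eval_steps = max(1, int(expected_steps * eval_epochs / num_epochs))
    if save_steps is None:
        save_steps = eval_steps
    elif save_steps > eval_steps:
        # save_steps must be a round multiple of eval_steps
        save_steps = max(1, save_steps // eval_steps) * eval_steps
    return expected_steps, eval_steps, save_steps


# e.g. 100,000 rows, 1 epoch, effective batch size 128 -> 781 steps, eval/save every 39 steps
print(derive_steps(100_000, 1, 128))
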
generate.py DELETED
@@ -1,1548 +0,0 @@
1
- import ast
2
- import functools
3
- import glob
4
- import inspect
5
- import queue
6
- import shutil
7
- import sys
8
- import os
9
- import time
10
- import traceback
11
- import typing
12
- import warnings
13
- from datetime import datetime
14
- import filelock
15
- import psutil
16
-
17
- os.environ['HF_HUB_DISABLE_TELEMETRY'] = '1'
18
- os.environ['BITSANDBYTES_NOWELCOME'] = '1'
19
- warnings.filterwarnings('ignore', category=UserWarning, message='TypedStorage is deprecated')
20
-
21
- from loaders import get_loaders
22
- from utils import set_seed, clear_torch_cache, save_generate_output, NullContext, wrapped_partial, EThread, get_githash, \
23
- import_matplotlib, get_device, makedirs, get_kwargs
24
-
25
- import_matplotlib()
26
-
27
- SEED = 1236
28
- set_seed(SEED)
29
-
30
- from typing import Union
31
-
32
- import fire
33
- import torch
34
- from peft import PeftModel
35
- from transformers import GenerationConfig, AutoModel, TextIteratorStreamer
36
- from accelerate import init_empty_weights, infer_auto_device_map
37
-
38
- from prompter import Prompter, inv_prompt_type_to_model_lower, non_hf_types
39
- from stopping import get_stopping
40
-
41
- eval_extra_columns = ['prompt', 'response', 'score']
42
-
43
- langchain_modes = ['Disabled', 'ChatLLM', 'LLM', 'All', 'wiki', 'wiki_full', 'UserData', 'MyData', 'github h2oGPT',
44
- 'DriverlessAI docs']
45
-
46
- scratch_base_dir = '/tmp/'
47
-
48
-
49
- def main(
50
- load_8bit: bool = False,
51
- load_4bit: bool = False,
52
- load_half: bool = True,
53
- infer_devices: bool = True,
54
- base_model: str = '',
55
- tokenizer_base_model: str = '',
56
- lora_weights: str = "",
57
- gpu_id: int = 0,
58
- compile_model: bool = True,
59
-
60
- prompt_type: Union[int, str] = None,
61
- # input to generation
62
- temperature: float = None,
63
- top_p: float = None,
64
- top_k: int = None,
65
- num_beams: int = None,
66
- repetition_penalty: float = None,
67
- num_return_sequences: int = None,
68
- do_sample: bool = None,
69
- max_new_tokens: int = None,
70
- min_new_tokens: int = None,
71
- early_stopping: Union[bool, str] = None,
72
- max_time: float = None,
73
-
74
- memory_restriction_level: int = None,
75
- debug: bool = False,
76
- save_dir: str = None,
77
- share: bool = True,
78
- local_files_only: bool = False,
79
- resume_download: bool = True,
80
- use_auth_token: Union[str, bool] = False,
81
- trust_remote_code: Union[str, bool] = True,
82
- offload_folder: str = "offline_folder",
83
-
84
- src_lang: str = "English",
85
- tgt_lang: str = "Russian",
86
-
87
- cli: bool = False,
88
- cli_loop: bool = True,
89
- gradio: bool = True,
90
- gradio_avoid_processing_markdown: bool = False,
91
- gradio_offline_level: int = 0,
92
- chat: bool = True,
93
- chat_context: bool = False,
94
- stream_output: bool = True,
95
- show_examples: bool = None,
96
- verbose: bool = False,
97
- h2ocolors: bool = False,
98
- height: int = 600,
99
- show_lora: bool = True,
100
- login_mode_if_model0: bool = False,
101
- block_gradio_exit: bool = True,
102
- concurrency_count: int = 1,
103
- api_open: bool = False,
104
- allow_api: bool = True,
105
- input_lines: int = 1,
106
- auth: typing.List[typing.Tuple[str, str]] = None,
107
-
108
- sanitize_user_prompt: bool = True,
109
- sanitize_bot_response: bool = True,
110
-
111
- extra_model_options: typing.List[str] = [],
112
- extra_lora_options: typing.List[str] = [],
113
-
114
- score_model: str = 'OpenAssistant/reward-model-deberta-v3-large-v2',
115
- auto_score: bool = True,
116
-
117
- eval_filename: str = None,
118
- eval_prompts_only_num: int = 0,
119
- eval_prompts_only_seed: int = 1234,
120
- eval_as_output: bool = False,
121
-
122
- langchain_mode: str = 'Disabled',
123
- visible_langchain_modes: list = ['UserData', 'MyData'],
124
- document_choice: list = ['All'],
125
- user_path: str = None,
126
- detect_user_path_changes_every_query: bool = False,
127
- load_db_if_exists: bool = True,
128
- keep_sources_in_context: bool = False,
129
- db_type: str = 'chroma',
130
- use_openai_embedding: bool = False,
131
- use_openai_model: bool = False,
132
- hf_embedding_model: str = "sentence-transformers/all-MiniLM-L6-v2",
133
- allow_upload_to_user_data: bool = True,
134
- allow_upload_to_my_data: bool = True,
135
- enable_url_upload: bool = True,
136
- enable_text_upload: bool = True,
137
- enable_sources_list: bool = True,
138
- chunk: bool = True,
139
- chunk_size: int = 512,
140
- top_k_docs: int = 3, # FIXME: Can go back to 4 once https://github.com/h2oai/h2ogpt/issues/192 fixed
141
- n_jobs: int = -1,
142
- enable_captions: bool = True,
143
- captions_model: str = "Salesforce/blip-image-captioning-base",
144
- pre_load_caption_model: bool = False,
145
- caption_gpu: bool = True,
146
- enable_ocr: bool = False,
147
- ):
148
- """
149
-
150
- :param load_8bit: load model in 8-bit using bitsandbytes
151
- :param load_4bit: load model in 4-bit using bitsandbytes
152
- :param load_half: load model in float16
153
- :param infer_devices: whether to control devices with gpu_id. If False, then spread across GPUs
154
- :param base_model: model HF-type name. If use --base_model to preload model, cannot unload in gradio in models tab
155
- :param tokenizer_base_model: tokenizer HF-type name. Usually not required, inferred from base_model.
156
- :param lora_weights: LORA weights path/HF link
157
- :param gpu_id: if infer_devices, then use gpu_id for cuda device ID, or auto mode if gpu_id != -1
158
- :param compile_model: Whether to compile the model
159
- :param prompt_type: type of prompt, usually matched to fine-tuned model or plain for foundational model
160
- :param temperature: generation temperature
161
- :param top_p: generation top_p
162
- :param top_k: generation top_k
163
- :param num_beams: generation number of beams
164
- :param repetition_penalty: generation repetition penalty
165
- :param num_return_sequences: generation number of sequences (1 forced for chat)
166
- :param do_sample: generation sample
167
- :param max_new_tokens: generation max new tokens
168
- :param min_new_tokens: generation min tokens
169
- :param early_stopping: generation early stopping
170
- :param max_time: maximum time to allow for generation
171
- :param memory_restriction_level: 0 = no restriction on tokens or model, 1 = some restrictions on tokens, 2 = HF-like restriction, 3 = very low memory case
172
- :param debug: enable debug mode
173
- :param save_dir: directory chat data is saved to
174
- :param share: whether to share the gradio app with sharable URL
175
- :param local_files_only: whether to only use local files instead of going to HF for models
176
- :param resume_download: whether to resume downloads from HF for models
177
- :param use_auth_token: whether to use HF auth token (requires CLI did huggingface-cli login before)
178
- :param trust_remote_code: whether to trust any code needed for HF model
179
- :param offload_folder: path for spilling model onto disk
180
- :param src_lang: source languages to include if doing translation (None = all)
181
- :param tgt_lang: target languages to include if doing translation (None = all)
182
- :param cli: whether to use CLI (non-gradio) interface.
183
- :param cli_loop: whether to loop for CLI (False usually only for testing)
184
- :param gradio: whether to enable gradio, or to enable benchmark mode
185
- :param gradio_avoid_processing_markdown:
186
- :param gradio_offline_level: > 0, then change fonts so full offline
187
- == 1 means backend won't need internet for fonts, but front-end UI might if font not cached
188
- == 2 means backend and frontend don't need internet to download any fonts.
189
- Note: Some things always disabled include HF telemetry, gradio telemetry, chromadb posthog that involve uploading.
190
- This option further disables google fonts for downloading, which is less intrusive than uploading,
191
- but still required in air-gapped case. The fonts don't look as nice as google fonts, but ensure full offline behavior.
192
- :param chat: whether to enable chat mode with chat history
193
- :param chat_context: whether to use extra helpful context if human_bot
194
- :param stream_output: whether to stream output from generate
195
- :param show_examples: whether to show clickable examples in gradio
196
- :param verbose: whether to show verbose prints
197
- :param h2ocolors: whether to use H2O.ai theme
198
- :param height: height of chat window
199
- :param show_lora: whether to show LORA options in UI (expert so can be hard to understand)
200
- :param login_mode_if_model0: set to True to load --base_model after client logs in, to be able to free GPU memory when model is swapped
201
- :param block_gradio_exit: whether to block gradio exit (used for testing)
202
- :param concurrency_count: gradio concurrency count (1 is optimal for LLMs)
203
- :param api_open: If False, don't let API calls skip gradio queue
204
- :param allow_api: whether to allow API calls at all to gradio server
205
- :param input_lines: how many input lines to show for chat box (>1 forces shift-enter for submit, else enter is submit)
206
- :param auth: gradio auth for launcher in form [(user1, pass1), (user2, pass2), ...]
207
- e.g. --auth=[('jon','password')] with no spaces
208
- :param sanitize_user_prompt: whether to remove profanity from user input
209
- :param sanitize_bot_response: whether to remove profanity and repeat lines from bot output
210
- :param extra_model_options: extra models to show in list in gradio
211
- :param extra_lora_options: extra LORA to show in list in gradio
212
- :param score_model: which model to score responses (None means no scoring)
213
- :param auto_score: whether to automatically score responses
214
- :param eval_filename: json file to use for evaluation, if None is sharegpt
215
- :param eval_prompts_only_num: for no gradio benchmark, if using eval_filename prompts for eval instead of examples
216
- :param eval_prompts_only_seed: for no gradio benchmark, seed for eval_filename sampling
217
- :param eval_as_output: for no gradio benchmark, whether to test eval_filename output itself
218
- :param langchain_mode: Data source to include. Choose "UserData" to only consume files from make_db.py.
219
- WARNING: wiki_full requires extra data processing via read_wiki_full.py and requires really good workstation to generate db, unless already present.
220
- :param user_path: user path to glob from to generate db for vector search, for 'UserData' langchain mode.
221
- If already have db, any new/changed files are added automatically if path set, does not have to be same path used for prior db sources
222
- :param detect_user_path_changes_every_query: whether to detect if any files changed or added every similarity search (by file hashes).
223
- Expensive for large number of files, so not done by default. By default only detect changes during db loading.
224
- :param visible_langchain_modes: dbs to generate at launch to be ready for LLM
225
- Can be up to ['wiki', 'wiki_full', 'UserData', 'MyData', 'github h2oGPT', 'DriverlessAI docs']
226
- But wiki_full is expensive and requires preparation
227
- To allow scratch space only live in session, add 'MyData' to list
228
- Default: If only want to consume local files, e.g. prepared by make_db.py, only include ['UserData']
229
- FIXME: Avoid 'All' for now, not implemented
230
- :param document_choice: Default document choice when taking subset of collection
231
- :param load_db_if_exists: Whether to load chroma db if exists or re-generate db
232
- :param keep_sources_in_context: Whether to keep url sources in context, not helpful usually
233
- :param db_type: 'faiss' for in-memory or 'chroma' or 'weaviate' for persisted on disk
234
- :param use_openai_embedding: Whether to use OpenAI embeddings for vector db
235
- :param use_openai_model: Whether to use OpenAI model for use with vector db
236
- :param hf_embedding_model: Which HF embedding model to use for vector db
237
- :param allow_upload_to_user_data: Whether to allow file uploads to update shared vector db
238
- :param allow_upload_to_my_data: Whether to allow file uploads to update scratch vector db
239
- :param enable_url_upload: Whether to allow upload from URL
240
- :param enable_text_upload: Whether to allow upload of text
241
- :param enable_sources_list: Whether to allow listing (or downloading, for non-shared db) the sources for the chosen db
242
- :param chunk: Whether to chunk data (True unless know data is already optimally chunked)
243
- :param chunk_size: Size of chunks, with typically top-4 passed to LLM, so needs to fit in context length
244
- :param top_k_docs: number of chunks to give LLM
245
- :param n_jobs: Number of processors to use when consuming documents (-1 = all, is default)
246
- :param enable_captions: Whether to support captions using BLIP for image files as documents, then preloads that model
247
- :param captions_model: Which model to use for captions.
248
- captions_model: str = "Salesforce/blip-image-captioning-base", # continue capable
249
- captions_model: str = "Salesforce/blip2-flan-t5-xl", # question/answer capable, 16GB state
250
- captions_model: str = "Salesforce/blip2-flan-t5-xxl", # question/answer capable, 60GB state
251
- Note: opt-based blip2 are not permissive license due to opt and Meta license restrictions
252
- :param pre_load_caption_model: Whether to preload caption model, or load after forking parallel doc loader
253
- parallel loading disabled if preload and have images, to prevent deadlocking on cuda context
254
- Recommended if using larger caption model
255
- :param caption_gpu: If support caption, then use GPU if exists
256
- :param enable_ocr: Whether to support OCR on images
257
- :return:
258
- """
259
- is_hf = bool(os.getenv("HUGGINGFACE_SPACES"))
260
- is_gpth2oai = bool(os.getenv("GPT_H2O_AI"))
261
- is_public = is_hf or is_gpth2oai # multi-user case with fixed model and disclaimer
262
- if memory_restriction_level is None:
263
- memory_restriction_level = 2 if is_hf else 0 # 2 assumes run on 24GB consumer GPU
264
- else:
265
- assert 0 <= memory_restriction_level <= 3, "Bad memory_restriction_level=%s" % memory_restriction_level
266
- admin_pass = os.getenv("ADMIN_PASS")
267
- # will sometimes appear in UI or sometimes actual generation, but maybe better than empty result
268
- # but becomes unrecoverable sometimes if raise, so just be silent for now
269
- raise_generate_gpu_exceptions = True
270
-
271
- # allow set token directly
272
- use_auth_token = os.environ.get("HUGGINGFACE_API_TOKEN", use_auth_token)
273
- allow_upload_to_user_data = bool(os.environ.get("allow_upload_to_user_data", allow_upload_to_user_data))
274
- allow_upload_to_my_data = bool(os.environ.get("allow_upload_to_my_data", allow_upload_to_my_data))
275
- height = os.environ.get("HEIGHT", height)
276
-
277
- # allow enabling langchain via ENV
278
- # FIRST PLACE where LangChain referenced, but no imports related to it
279
- langchain_mode = os.environ.get("LANGCHAIN_MODE", langchain_mode)
280
- assert langchain_mode in langchain_modes, "Invalid langchain_mode %s" % langchain_mode
281
- visible_langchain_modes = ast.literal_eval(os.environ.get("visible_langchain_modes", str(visible_langchain_modes)))
282
- if langchain_mode not in visible_langchain_modes and langchain_mode in langchain_modes:
283
- visible_langchain_modes += [langchain_mode]
284
-
285
- if is_public:
286
- allow_upload_to_user_data = False
287
- input_lines = 1 # ensure set, for ease of use
288
- temperature = 0.2 if temperature is None else temperature
289
- top_p = 0.85 if top_p is None else top_p
290
- top_k = 70 if top_k is None else top_k
291
- if is_hf:
292
- do_sample = True if do_sample is None else do_sample
293
- else:
294
- # by default don't sample, too chatty
295
- do_sample = False if do_sample is None else do_sample
296
-
297
- if memory_restriction_level == 2:
298
- if not base_model:
299
- base_model = 'h2oai/h2ogpt-oasst1-512-12b'
300
- # don't set load_8bit if passed base_model, doesn't always work so can't just override
301
- load_8bit = True
302
- load_4bit = False # FIXME - consider using 4-bit instead of 8-bit
303
- else:
304
- base_model = 'h2oai/h2ogpt-oasst1-512-20b' if not base_model else base_model
305
- if memory_restriction_level >= 2:
306
- load_8bit = True
307
- load_4bit = False # FIXME - consider using 4-bit instead of 8-bit
308
- if is_hf:
309
- # must override share if in spaces
310
- share = False
311
- save_dir = os.getenv('SAVE_DIR', save_dir)
312
- score_model = os.getenv('SCORE_MODEL', score_model)
313
- if score_model == 'None' or score_model is None:
314
- score_model = ''
315
- concurrency_count = int(os.getenv('CONCURRENCY_COUNT', concurrency_count))
316
- api_open = bool(int(os.getenv('API_OPEN', api_open)))
317
- allow_api = bool(int(os.getenv('ALLOW_API', allow_api)))
318
-
319
- n_gpus = torch.cuda.device_count() if torch.cuda.is_available() else 0
320
- if n_gpus == 0:
321
- gpu_id = None
322
- load_8bit = False
323
- load_4bit = False
324
- load_half = False
325
- infer_devices = False
326
- torch.backends.cudnn.benchmark = True
327
- torch.backends.cudnn.enabled = False
328
- torch.set_default_dtype(torch.float32)
329
- if psutil.virtual_memory().available < 94 * 1024 ** 3:
330
- # 12B uses ~94GB
331
- # 6.9B uses ~47GB
332
- base_model = 'h2oai/h2ogpt-oig-oasst1-512-6_9b' if not base_model else base_model
333
-
334
- # get defaults
335
- model_lower = base_model.lower()
336
- if not gradio:
337
- # force, else not single response like want to look at
338
- stream_output = False
339
- # else prompt removal can mess up output
340
- chat = False
341
- # hard-coded defaults
342
- first_para = False
343
- text_limit = None
344
-
345
- if offload_folder:
346
- makedirs(offload_folder)
347
-
348
- placeholder_instruction, placeholder_input, \
349
- stream_output, show_examples, \
350
- prompt_type, temperature, top_p, top_k, num_beams, \
351
- max_new_tokens, min_new_tokens, early_stopping, max_time, \
352
- repetition_penalty, num_return_sequences, \
353
- do_sample, \
354
- src_lang, tgt_lang, \
355
- examples, \
356
- task_info = \
357
- get_generate_params(model_lower, chat,
358
- stream_output, show_examples,
359
- prompt_type, temperature, top_p, top_k, num_beams,
360
- max_new_tokens, min_new_tokens, early_stopping, max_time,
361
- repetition_penalty, num_return_sequences,
362
- do_sample,
363
- top_k_docs,
364
- verbose,
365
- )
366
-
367
- locals_dict = locals()
368
- locals_print = '\n'.join(['%s: %s' % (k, v) for k, v in locals_dict.items()])
369
- if verbose:
370
- print(f"Generating model with params:\n{locals_print}", flush=True)
371
- print("Command: %s\nHash: %s" % (str(' '.join(sys.argv)), get_githash()), flush=True)
372
-
373
- if langchain_mode != "Disabled":
374
- # SECOND PLACE where LangChain referenced, but all imports are kept local so not required
375
- from gpt_langchain import prep_langchain, get_some_dbs_from_hf
376
- if is_hf:
377
- get_some_dbs_from_hf()
378
- dbs = {}
379
- for langchain_mode1 in visible_langchain_modes:
380
- if langchain_mode1 in ['MyData']:
381
- # don't use what is on disk, remove it instead
382
- for gpath1 in glob.glob(os.path.join(scratch_base_dir, 'db_dir_%s*' % langchain_mode1)):
383
- if os.path.isdir(gpath1):
384
- print("Removing old MyData: %s" % gpath1, flush=True)
385
- shutil.rmtree(gpath1)
386
- continue
387
- if langchain_mode1 in ['All']:
388
- # FIXME: All should be avoided until scans over each db, shouldn't be separate db
389
- continue
390
- persist_directory1 = 'db_dir_%s' % langchain_mode1 # single place, no special names for each case
391
- db = prep_langchain(persist_directory1,
392
- load_db_if_exists,
393
- db_type, use_openai_embedding,
394
- langchain_mode1, user_path,
395
- hf_embedding_model,
396
- kwargs_make_db=locals())
397
- dbs[langchain_mode1] = db
398
- # remove None dbs so we can just rely upon "k in dbs" to check if we have a db
399
- dbs = {k: v for k, v in dbs.items() if v is not None}
400
- else:
401
- dbs = {}
402
- # import control
403
- if os.environ.get("TEST_LANGCHAIN_IMPORT"):
404
- assert 'gpt_langchain' not in sys.modules, "Dev bug, import of langchain when should not have"
405
- assert 'langchain' not in sys.modules, "Dev bug, import of langchain when should not have"
406
-
407
- if cli:
408
- from cli import run_cli
409
- return run_cli(**get_kwargs(run_cli, exclude_names=['model_state0'], **locals()))
410
- elif not gradio:
411
- from eval import run_eval
412
- return run_eval(**get_kwargs(run_eval, exclude_names=['model_state0'], **locals()))
413
- elif gradio:
414
- # imported here so don't require gradio to run generate
415
- from gradio_runner import go_gradio
416
-
417
- # get default model
418
- all_kwargs = locals().copy()
419
- if all_kwargs.get('base_model') and not all_kwargs['login_mode_if_model0']:
420
- model0, tokenizer0, device = get_model(reward_type=False,
421
- **get_kwargs(get_model, exclude_names=['reward_type'], **all_kwargs))
422
- else:
423
- # if empty model, then don't load anything, just get gradio up
424
- model0, tokenizer0, device = None, None, None
425
- model_state0 = [model0, tokenizer0, device, all_kwargs['base_model']]
426
-
427
- # get score model
428
- smodel, stokenizer, sdevice = get_score_model(reward_type=True,
429
- **get_kwargs(get_score_model, exclude_names=['reward_type'],
430
- **all_kwargs))
431
- score_model_state0 = [smodel, stokenizer, sdevice, score_model]
432
-
433
- if enable_captions:
434
- if pre_load_caption_model:
435
- from image_captions import H2OImageCaptionLoader
436
- caption_loader = H2OImageCaptionLoader(caption_gpu=caption_gpu).load_model()
437
- else:
438
- caption_loader = 'gpu' if caption_gpu else 'cpu'
439
- else:
440
- caption_loader = False
441
-
442
- # assume gradio needs everything
443
- go_gradio(**locals())
444
-
445
-
446
- def get_non_lora_model(base_model, model_loader, load_half, model_kwargs, reward_type,
447
- gpu_id=0,
448
- use_auth_token=False,
449
- trust_remote_code=True,
450
- offload_folder=None,
451
- triton_attn=False,
452
- long_sequence=True,
453
- ):
454
- """
455
- Ensure model gets on correct device
456
- :param base_model:
457
- :param model_loader:
458
- :param load_half:
459
- :param model_kwargs:
460
- :param reward_type:
461
- :param gpu_id:
462
- :param use_auth_token:
463
- :param trust_remote_code:
464
- :param offload_folder:
465
- :param triton_attn:
466
- :param long_sequence:
467
- :return:
468
- """
469
- with init_empty_weights():
470
- from transformers import AutoConfig
471
- config = AutoConfig.from_pretrained(base_model, use_auth_token=use_auth_token,
472
- trust_remote_code=trust_remote_code,
473
- offload_folder=offload_folder)
474
- if triton_attn and 'mpt-' in base_model.lower():
475
- config.attn_config['attn_impl'] = 'triton'
476
- if long_sequence:
477
- if 'mpt-7b-storywriter' in base_model.lower():
478
- config.update({"max_seq_len": 83968})
479
- if 'mosaicml/mpt-7b-chat' in base_model.lower():
480
- config.update({"max_seq_len": 4096})
481
- if issubclass(config.__class__, tuple(AutoModel._model_mapping.keys())):
482
- model = AutoModel.from_config(
483
- config,
484
- )
485
- else:
486
- # can't infer
487
- model = None
488
-
489
- if model is not None:
490
- # NOTE: Can specify max_memory={0: max_mem, 1: max_mem}, to shard model
491
- # NOTE: Some models require avoiding sharding some layers,
492
- # then would pass no_split_module_classes and give list of those layers.
493
- device_map = infer_auto_device_map(
494
- model,
495
- dtype=torch.float16 if load_half else torch.float32,
496
- )
497
- if hasattr(model, 'model'):
498
- device_map_model = infer_auto_device_map(
499
- model.model,
500
- dtype=torch.float16 if load_half else torch.float32,
501
- )
502
- device_map.update(device_map_model)
503
- else:
504
- device_map = "auto"
505
-
506
- n_gpus = torch.cuda.device_count() if torch.cuda.is_available() else 0
507
-
508
- if n_gpus > 0:
509
- if gpu_id >= 0:
510
- # FIXME: If really distributes model, tend to get things like: ValueError: gpt_neox.embed_in.weight doesn't have any device set.
511
- # So avoid for now, just put on first GPU, unless score_model, put on last
512
- if reward_type:
513
- device_map = {'': n_gpus - 1}
514
- else:
515
- device_map = {'': min(n_gpus - 1, gpu_id)}
516
- if gpu_id == -1:
517
- device_map = {'': 'cuda'}
518
- else:
519
- device_map = {'': 'cpu'}
520
- model_kwargs['load_in_8bit'] = False
521
- model_kwargs['load_in_4bit'] = False
522
- print('device_map: %s' % device_map, flush=True)
523
-
524
- load_in_8bit = model_kwargs.get('load_in_8bit', False)
525
- load_in_4bit = model_kwargs.get('load_in_4bit', False)
526
- model_kwargs['device_map'] = device_map
527
- pop_unused_model_kwargs(model_kwargs)
528
-
529
- if load_in_8bit or load_in_4bit or not load_half:
530
- model = model_loader.from_pretrained(
531
- base_model,
532
- config=config,
533
- **model_kwargs,
534
- )
535
- else:
536
- model = model_loader.from_pretrained(
537
- base_model,
538
- config=config,
539
- **model_kwargs,
540
- ).half()
541
- return model
542
-
543
-
544
- def get_model(
545
- load_8bit: bool = False,
546
- load_4bit: bool = False,
547
- load_half: bool = True,
548
- infer_devices: bool = True,
549
- base_model: str = '',
550
- tokenizer_base_model: str = '',
551
- lora_weights: str = "",
552
- gpu_id: int = 0,
553
-
554
- reward_type: bool = None,
555
- local_files_only: bool = False,
556
- resume_download: bool = True,
557
- use_auth_token: Union[str, bool] = False,
558
- trust_remote_code: bool = True,
559
- offload_folder: str = None,
560
- compile_model: bool = True,
561
-
562
- verbose: bool = False,
563
- ):
564
- """
565
-
566
- :param load_8bit: load model in 8-bit, not supported by all models
567
- :param load_4bit: load model in 4-bit, not supported by all models
568
- :param load_half: load model in 16-bit
569
- :param infer_devices: Use torch infer of optimal placement of layers on devices (for non-lora case)
570
- For non-LORA case, False will spread shards across multiple GPUs, but this can lead to cuda:x cuda:y mismatches
571
- So it is not the default
572
- :param base_model: name/path of base model
573
- :param tokenizer_base_model: name/path of tokenizer
574
- :param lora_weights: name/path
575
- :param gpu_id: which GPU (0..n_gpus-1) or allow all GPUs if relevant (-1)
576
- :param reward_type: reward type model for sequence classification
577
- :param local_files_only: use local files instead of from HF
578
- :param resume_download: resume downloads from HF
579
- :param use_auth_token: assumes user did on CLI `huggingface-cli login` to access private repo
580
- :param trust_remote_code: trust code needed by model
581
- :param offload_folder: offload folder
582
- :param compile_model: whether to compile torch model
583
- :param verbose:
584
- :return:
585
- """
586
- if verbose:
587
- print("Get %s model" % base_model, flush=True)
588
- if base_model in non_hf_types:
589
- from gpt4all_llm import get_model_tokenizer_gpt4all
590
- model, tokenizer, device = get_model_tokenizer_gpt4all(base_model)
591
- return model, tokenizer, device
592
-
593
- if lora_weights is not None and lora_weights.strip():
594
- if verbose:
595
- print("Get %s lora weights" % lora_weights, flush=True)
596
- device = get_device()
597
-
598
- if 'gpt2' in base_model.lower():
599
- # RuntimeError: where expected condition to be a boolean tensor, but got a tensor with dtype Half
600
- load_8bit = False
601
- load_4bit = False
602
-
603
- assert base_model.strip(), (
604
- "Please choose a base model with --base_model (CLI) or load one from Models Tab (gradio)"
605
- )
606
-
607
- from transformers import AutoConfig
608
- config = AutoConfig.from_pretrained(base_model, use_auth_token=use_auth_token,
609
- trust_remote_code=trust_remote_code,
610
- offload_folder=offload_folder)
611
- llama_type_from_config = 'llama' in str(config).lower()
612
- llama_type_from_name = "llama" in base_model.lower()
613
- llama_type = llama_type_from_config or llama_type_from_name
614
- if llama_type:
615
- if verbose:
616
- print("Detected as llama type from"
617
- " config (%s) or name (%s)" % (llama_type_from_config, llama_type_from_name), flush=True)
618
-
619
- model_loader, tokenizer_loader = get_loaders(llama_type=llama_type, model_name=base_model, reward_type=reward_type)
620
- if not tokenizer_base_model:
621
- tokenizer_base_model = base_model
622
-
623
- if tokenizer_loader is not None and not isinstance(tokenizer_loader, str):
624
- tokenizer = tokenizer_loader.from_pretrained(tokenizer_base_model,
625
- local_files_only=local_files_only,
626
- resume_download=resume_download,
627
- use_auth_token=use_auth_token,
628
- trust_remote_code=trust_remote_code,
629
- offload_folder=offload_folder,
630
- )
631
- else:
632
- tokenizer = tokenizer_loader
633
-
634
- if isinstance(tokenizer, str):
635
- # already a pipeline, tokenizer_loader is string for task
636
- model = model_loader(tokenizer,
637
- model=base_model,
638
- device=0 if device == "cuda" else -1,
639
- torch_dtype=torch.float16 if device == 'cuda' else torch.float32)
640
- else:
641
- assert device in ["cuda", "cpu"], "Unsupported device %s" % device
642
- model_kwargs = dict(local_files_only=local_files_only,
643
- torch_dtype=torch.float16 if device == 'cuda' else torch.float32,
644
- resume_download=resume_download,
645
- use_auth_token=use_auth_token,
646
- trust_remote_code=trust_remote_code,
647
- offload_folder=offload_folder,
648
- )
649
- if 'mbart-' not in base_model.lower() and 'mpt-' not in base_model.lower():
650
- model_kwargs.update(dict(load_in_8bit=load_8bit,
651
- load_in_4bit=load_4bit,
652
- device_map={"": 0} if (load_8bit or load_4bit) and device == 'cuda' else "auto",
653
- ))
654
- if 'mpt-' in base_model.lower() and gpu_id >= 0:
655
- model_kwargs.update(dict(device_map={"": gpu_id} if device == 'cuda' else "cpu"))
656
-
657
- if 'OpenAssistant/reward-model'.lower() in base_model.lower():
658
- # FIXME: could put on other GPUs
659
- model_kwargs['device_map'] = {"": 0} if device == 'cuda' else {"": 'cpu'}
660
- model_kwargs.pop('torch_dtype', None)
661
- pop_unused_model_kwargs(model_kwargs)
662
-
663
- if not lora_weights:
664
- with torch.device(device):
665
- if infer_devices:
666
- model = get_non_lora_model(base_model, model_loader, load_half, model_kwargs, reward_type,
667
- gpu_id=gpu_id,
668
- use_auth_token=use_auth_token,
669
- trust_remote_code=trust_remote_code,
670
- offload_folder=offload_folder,
671
- )
672
- else:
673
- if load_half and not (load_8bit or load_4bit):
674
- model = model_loader.from_pretrained(
675
- base_model,
676
- **model_kwargs).half()
677
- else:
678
- model = model_loader.from_pretrained(
679
- base_model,
680
- **model_kwargs)
681
- elif load_8bit or load_4bit:
682
- model = model_loader.from_pretrained(
683
- base_model,
684
- **model_kwargs
685
- )
686
- model = PeftModel.from_pretrained(
687
- model,
688
- lora_weights,
689
- torch_dtype=torch.float16 if device == 'cuda' else torch.float32,
690
- local_files_only=local_files_only,
691
- resume_download=resume_download,
692
- use_auth_token=use_auth_token,
693
- trust_remote_code=trust_remote_code,
694
- offload_folder=offload_folder,
695
- device_map={"": 0} if device == 'cuda' else {"": 'cpu'}, # seems to be required
696
- )
697
- else:
698
- with torch.device(device):
699
- model = model_loader.from_pretrained(
700
- base_model,
701
- **model_kwargs
702
- )
703
- model = PeftModel.from_pretrained(
704
- model,
705
- lora_weights,
706
- torch_dtype=torch.float16 if device == 'cuda' else torch.float32,
707
- local_files_only=local_files_only,
708
- resume_download=resume_download,
709
- use_auth_token=use_auth_token,
710
- trust_remote_code=trust_remote_code,
711
- offload_folder=offload_folder,
712
- device_map="auto",
713
- )
714
- if load_half:
715
- model.half()
716
-
717
- # unwind broken decapoda-research config
718
- if llama_type:
719
- model.config.pad_token_id = tokenizer.pad_token_id = 0 # unk
720
- model.config.bos_token_id = 1
721
- model.config.eos_token_id = 2
722
- if 'gpt2' in base_model.lower():
723
- # add special tokens that otherwise all share the same id
724
- tokenizer.add_special_tokens({'bos_token': '<bos>',
725
- 'eos_token': '<eos>',
726
- 'pad_token': '<pad>'})
727
-
728
- if not isinstance(tokenizer, str):
729
- model.eval()
730
- if torch.__version__ >= "2" and sys.platform != "win32" and compile_model:
731
- model = torch.compile(model)
732
-
733
- if hasattr(config, 'max_position_embeddings') and isinstance(config.max_position_embeddings, int):
734
- # help automatically limit inputs to generate
735
- tokenizer.model_max_length = config.max_position_embeddings
736
- else:
737
- tokenizer.model_max_length = 2048
738
-
739
- return model, tokenizer, device
740
-
741
-
742
- def pop_unused_model_kwargs(model_kwargs):
743
- """
744
- in-place pop unused kwargs that are not dependency-upgrade friendly
745
- no point passing in False since it is the default, and this helps avoid needing to update requirements for new deps
746
- :param model_kwargs:
747
- :return:
748
- """
749
- check_list = ['load_in_8bit', 'load_in_4bit']
750
- for k in check_list:
751
- if k in model_kwargs and not model_kwargs[k]:
752
- model_kwargs.pop(k)
753
-
754
-
755
- def get_score_model(score_model: str = None,
756
- load_8bit: bool = False,
757
- load_4bit: bool = False,
758
- load_half: bool = True,
759
- infer_devices: bool = True,
760
- base_model: str = '',
761
- tokenizer_base_model: str = '',
762
- lora_weights: str = "",
763
- gpu_id: int = 0,
764
-
765
- reward_type: bool = None,
766
- local_files_only: bool = False,
767
- resume_download: bool = True,
768
- use_auth_token: Union[str, bool] = False,
769
- trust_remote_code: bool = True,
770
- offload_folder: str = None,
771
- compile_model: bool = True,
772
-
773
- verbose: bool = False,
774
- ):
775
- if score_model is not None and score_model.strip():
776
- load_8bit = False
777
- load_4bit = False
778
- load_half = False
779
- base_model = score_model.strip()
780
- tokenizer_base_model = ''
781
- lora_weights = ''
782
- llama_type = False
783
- compile_model = False
784
- smodel, stokenizer, sdevice = get_model(reward_type=True,
785
- **get_kwargs(get_model, exclude_names=['reward_type'], **locals()))
786
- else:
787
- smodel, stokenizer, sdevice = None, None, None
788
- return smodel, stokenizer, sdevice
789
-
790
-
791
- eval_func_param_names = ['instruction',
792
- 'iinput',
793
- 'context',
794
- 'stream_output',
795
- 'prompt_type',
796
- 'temperature',
797
- 'top_p',
798
- 'top_k',
799
- 'num_beams',
800
- 'max_new_tokens',
801
- 'min_new_tokens',
802
- 'early_stopping',
803
- 'max_time',
804
- 'repetition_penalty',
805
- 'num_return_sequences',
806
- 'do_sample',
807
- 'chat',
808
- 'instruction_nochat',
809
- 'iinput_nochat',
810
- 'langchain_mode',
811
- 'top_k_docs',
812
- 'document_choice',
813
- ]
814
-
815
-
816
- def evaluate(
817
- model_state,
818
- my_db_state,
819
- # START NOTE: Examples must have same order of parameters
820
- instruction,
821
- iinput,
822
- context,
823
- stream_output,
824
- prompt_type,
825
- temperature,
826
- top_p,
827
- top_k,
828
- num_beams,
829
- max_new_tokens,
830
- min_new_tokens,
831
- early_stopping,
832
- max_time,
833
- repetition_penalty,
834
- num_return_sequences,
835
- do_sample,
836
- chat,
837
- instruction_nochat,
838
- iinput_nochat,
839
- langchain_mode,
840
- top_k_docs,
841
- document_choice,
842
- # END NOTE: Examples must have same order of parameters
843
- src_lang=None,
844
- tgt_lang=None,
845
- debug=False,
846
- concurrency_count=None,
847
- save_dir=None,
848
- sanitize_bot_response=True,
849
- model_state0=None,
850
- memory_restriction_level=None,
851
- raise_generate_gpu_exceptions=None,
852
- chat_context=None,
853
- lora_weights=None,
854
- load_db_if_exists=True,
855
- dbs=None,
856
- user_path=None,
857
- detect_user_path_changes_every_query=None,
858
- use_openai_embedding=None,
859
- use_openai_model=None,
860
- hf_embedding_model=None,
861
- chunk=None,
862
- chunk_size=None,
863
- db_type=None,
864
- n_jobs=None,
865
- first_para=None,
866
- text_limit=None,
867
- verbose=False,
868
- cli=False,
869
- ):
870
- # ensure passed these
871
- assert concurrency_count is not None
872
- assert memory_restriction_level is not None
873
- assert raise_generate_gpu_exceptions is not None
874
- assert chat_context is not None
875
- assert use_openai_embedding is not None
876
- assert use_openai_model is not None
877
- assert hf_embedding_model is not None
878
- assert chunk is not None
879
- assert chunk_size is not None
880
- assert db_type is not None
881
- assert top_k_docs is not None and isinstance(top_k_docs, int)
882
- assert n_jobs is not None
883
- assert first_para is not None
884
-
885
- if debug:
886
- locals_dict = locals().copy()
887
- locals_dict.pop('model_state', None)
888
- locals_dict.pop('model_state0', None)
889
- print(locals_dict)
890
-
891
- no_model_msg = "Please choose a base model with --base_model (CLI) or load in Models Tab (gradio).\nThen start New Conversation"
892
-
893
- if model_state0 is None:
894
- # e.g. for no gradio case, set dummy value, else should be set
895
- model_state0 = [None, None, None, None]
896
-
897
- if model_state is not None and len(model_state) == 4 and not isinstance(model_state[0], str):
898
- # try to free-up original model (i.e. list was passed as reference)
899
- if model_state0 is not None and model_state0[0] is not None:
900
- model_state0[0].cpu()
901
- model_state0[0] = None
902
- # try to free-up original tokenizer (i.e. list was passed as reference)
903
- if model_state0 is not None and model_state0[1] is not None:
904
- model_state0[1] = None
905
- clear_torch_cache()
906
- model, tokenizer, device, base_model = model_state
907
- elif model_state0 is not None and len(model_state0) == 4 and model_state0[0] is not None:
908
- assert isinstance(model_state[0], str)
909
- model, tokenizer, device, base_model = model_state0
910
- else:
911
- raise AssertionError(no_model_msg)
912
-
913
- if base_model is None:
914
- raise AssertionError(no_model_msg)
915
-
916
- assert base_model.strip(), no_model_msg
917
- assert model, "Model is missing"
918
- assert tokenizer, "Tokenizer is missing"
919
-
920
- # choose chat or non-chat mode
921
- if not chat:
922
- instruction = instruction_nochat
923
- iinput = iinput_nochat
924
-
925
- if not context:
926
- # get hidden context if have one
927
- context = get_context(chat_context, prompt_type)
928
-
929
- prompter = Prompter(prompt_type, debug=debug, chat=chat, stream_output=stream_output)
930
- data_point = dict(context=context, instruction=instruction, input=iinput)
931
- prompt = prompter.generate_prompt(data_point)
932
-
933
- # THIRD PLACE where LangChain referenced, but imports only occur if enabled and have db to use
934
- assert langchain_mode in langchain_modes, "Invalid langchain_mode %s" % langchain_mode
935
- if langchain_mode in ['MyData'] and my_db_state is not None and len(my_db_state) > 0 and my_db_state[0] is not None:
936
- db1 = my_db_state[0]
937
- elif dbs is not None and langchain_mode in dbs:
938
- db1 = dbs[langchain_mode]
939
- else:
940
- db1 = None
941
- if langchain_mode not in [False, 'Disabled', 'ChatLLM', 'LLM'] and db1 is not None or base_model in non_hf_types:
942
- query = instruction if not iinput else "%s\n%s" % (instruction, iinput)
943
- outr = ""
944
- # use smaller cut_distanct for wiki_full since so many matches could be obtained, and often irrelevant unless close
945
- from gpt_langchain import run_qa_db
946
- for r in run_qa_db(query=query,
947
- model_name=base_model, model=model, tokenizer=tokenizer,
948
- stream_output=stream_output,
949
- prompter=prompter,
950
- load_db_if_exists=load_db_if_exists,
951
- db=db1,
952
- user_path=user_path,
953
- detect_user_path_changes_every_query=detect_user_path_changes_every_query,
954
- max_new_tokens=max_new_tokens,
955
- cut_distanct=1.1 if langchain_mode in ['wiki_full'] else 1.64, # FIXME, too arbitrary
956
- use_openai_embedding=use_openai_embedding,
957
- use_openai_model=use_openai_model,
958
- hf_embedding_model=hf_embedding_model,
959
- first_para=first_para,
960
- text_limit=text_limit,
961
- chunk=chunk,
962
- chunk_size=chunk_size,
963
- langchain_mode=langchain_mode,
964
- document_choice=document_choice,
965
- db_type=db_type,
966
- k=top_k_docs,
967
- temperature=temperature,
968
- repetition_penalty=repetition_penalty,
969
- top_k=top_k,
970
- top_p=top_p,
971
- prompt_type=prompt_type,
972
- n_jobs=n_jobs,
973
- verbose=verbose,
974
- cli=cli,
975
- ):
976
- outr, extra = r # doesn't accumulate, new answer every yield, so only save that full answer
977
- yield dict(response=outr, sources=extra)
978
- if save_dir:
979
- save_generate_output(output=outr, base_model=base_model, save_dir=save_dir)
980
- if verbose:
981
- print(
982
- 'Post-Generate Langchain: %s decoded_output: %s' % (str(datetime.now()), len(outr) if outr else -1),
983
- flush=True)
984
- if outr or base_model in non_hf_types:
985
- # if got no response (e.g. not showing sources and got no sources,
986
- # so nothing to give to LLM), then slip through and ask LLM
987
- # Or if llama/gptj, then just return since they had no response and can't go down below code path
988
- return
989
-
990
- if isinstance(tokenizer, str):
991
- # pipeline
992
- if tokenizer == "summarization":
993
- key = 'summary_text'
994
- else:
995
- raise RuntimeError("No such task type %s" % tokenizer)
996
- # NOTE: uses max_length only
997
- yield dict(response=model(prompt, max_length=max_new_tokens)[0][key], sources='')
998
-
999
- if 'mbart-' in base_model.lower():
1000
- assert src_lang is not None
1001
- tokenizer.src_lang = languages_covered()[src_lang]
1002
-
1003
- if chat:
1004
- # override, ignore user change
1005
- num_return_sequences = 1
1006
- stopping_criteria = get_stopping(prompt_type, tokenizer, device)
1007
- _, _, max_length_tokenize, max_prompt_length = get_cutoffs(memory_restriction_level, model_max_length=tokenizer.model_max_length)
1008
- prompt = prompt[-max_prompt_length:]
1009
- inputs = tokenizer(prompt,
1010
- return_tensors="pt",
1011
- truncation=True,
1012
- max_length=max_length_tokenize)
1013
- if inputs['input_ids'].shape[1] >= max_length_tokenize - 1:
1014
- print("Cutting off input: %s %s" % (inputs['input_ids'].shape[1], max_length_tokenize), flush=True)
1015
- if debug and len(inputs["input_ids"]) > 0:
1016
- print('input_ids length', len(inputs["input_ids"][0]), flush=True)
1017
- input_ids = inputs["input_ids"].to(device)
1018
- # CRITICAL LIMIT else will fail
1019
- max_max_tokens = tokenizer.model_max_length
1020
- max_input_tokens = max_max_tokens - max_new_tokens
1021
- input_ids = input_ids[:, -max_input_tokens:]
1022
- generation_config = GenerationConfig(
1023
- temperature=float(temperature),
1024
- top_p=float(top_p),
1025
- top_k=top_k,
1026
- num_beams=num_beams,
1027
- do_sample=do_sample,
1028
- repetition_penalty=float(repetition_penalty),
1029
- num_return_sequences=num_return_sequences,
1030
- renormalize_logits=True,
1031
- remove_invalid_values=True,
1032
- )
1033
-
1034
- gen_kwargs = dict(input_ids=input_ids,
1035
- generation_config=generation_config,
1036
- return_dict_in_generate=True,
1037
- output_scores=True,
1038
- max_new_tokens=max_new_tokens, # prompt + new
1039
- min_new_tokens=min_new_tokens, # prompt + new
1040
- early_stopping=early_stopping, # False, True, "never"
1041
- max_time=max_time,
1042
- stopping_criteria=stopping_criteria,
1043
- )
1044
- if 'gpt2' in base_model.lower():
1045
- gen_kwargs.update(dict(bos_token_id=tokenizer.bos_token_id, pad_token_id=tokenizer.eos_token_id))
1046
- elif 'mbart-' in base_model.lower():
1047
- assert tgt_lang is not None
1048
- tgt_lang = languages_covered()[tgt_lang]
1049
- gen_kwargs.update(dict(forced_bos_token_id=tokenizer.lang_code_to_id[tgt_lang]))
1050
- else:
1051
- gen_kwargs.update(dict(pad_token_id=tokenizer.eos_token_id))
1052
-
1053
- decoder_kwargs = dict(skip_special_tokens=True,
1054
- clean_up_tokenization_spaces=True)
1055
-
1056
- decoder = functools.partial(tokenizer.decode,
1057
- **decoder_kwargs
1058
- )
1059
- decoder_raw_kwargs = dict(skip_special_tokens=False,
1060
- clean_up_tokenization_spaces=True)
1061
-
1062
- decoder_raw = functools.partial(tokenizer.decode,
1063
- **decoder_raw_kwargs
1064
- )
1065
-
1066
- with torch.no_grad():
1067
- context_class_cast = NullContext if device == 'cpu' or lora_weights else torch.autocast
1068
- with context_class_cast(device):
1069
- # protection for gradio not keeping track of closed users,
1070
- # else hit bitsandbytes lack of thread safety:
1071
- # https://github.com/h2oai/h2ogpt/issues/104
1072
- # but only makes sense if concurrency_count == 1
1073
- context_class = NullContext # if concurrency_count > 1 else filelock.FileLock
1074
- if verbose:
1075
- print('Pre-Generate: %s' % str(datetime.now()), flush=True)
1076
- decoded_output = None
1077
- with context_class("generate.lock"):
1078
- if verbose:
1079
- print('Generate: %s' % str(datetime.now()), flush=True)
1080
- # decoded tokenized prompt can deviate from prompt due to special characters
1081
- inputs_decoded = decoder(input_ids[0])
1082
- inputs_decoded_raw = decoder_raw(input_ids[0])
1083
- if inputs_decoded == prompt:
1084
- # normal
1085
- pass
1086
- elif inputs_decoded.lstrip() == prompt.lstrip():
1087
- # sometimes extra space in front, make prompt same for prompt removal
1088
- prompt = inputs_decoded
1089
- elif inputs_decoded_raw == prompt:
1090
- # some models specify special tokens that are part of normal prompt, so can't skip them
1091
- inputs_decoded = prompt = inputs_decoded_raw
1092
- decoder = decoder_raw
1093
- decoder_kwargs = decoder_raw_kwargs
1094
- elif inputs_decoded_raw.replace("<unk> ", "").replace("<unk>", "").replace('\n', ' ').replace(' ',
1095
- '') == prompt.replace(
1096
- '\n', ' ').replace(' ', ''):
1097
- inputs_decoded = prompt = inputs_decoded_raw
1098
- decoder = decoder_raw
1099
- decoder_kwargs = decoder_raw_kwargs
1100
- else:
1101
- if verbose:
1102
- print("WARNING: Special characters in prompt", flush=True)
1103
- if stream_output:
1104
- skip_prompt = False
1105
- streamer = H2OTextIteratorStreamer(tokenizer, skip_prompt=skip_prompt, block=False,
1106
- **decoder_kwargs)
1107
- gen_kwargs.update(dict(streamer=streamer))
1108
- target = wrapped_partial(generate_with_exceptions, model.generate,
1109
- prompt=prompt, inputs_decoded=inputs_decoded,
1110
- raise_generate_gpu_exceptions=raise_generate_gpu_exceptions,
1111
- **gen_kwargs)
1112
- bucket = queue.Queue()
1113
- thread = EThread(target=target, streamer=streamer, bucket=bucket)
1114
- thread.start()
1115
- outputs = ""
1116
- try:
1117
- for new_text in streamer:
1118
- if bucket.qsize() > 0 or thread.exc:
1119
- thread.join()
1120
- outputs += new_text
1121
- yield dict(response=prompter.get_response(outputs, prompt=inputs_decoded,
1122
- sanitize_bot_response=sanitize_bot_response),
1123
- sources='')
1124
- except BaseException:
1125
- # if any exception, raise that exception if was from thread, first
1126
- if thread.exc:
1127
- raise thread.exc
1128
- raise
1129
- finally:
1130
- # in case no exception and didn't join with thread yet, then join
1131
- if not thread.exc:
1132
- thread.join()
1133
- # in case raise StopIteration or broke queue loop in streamer, but still have exception
1134
- if thread.exc:
1135
- raise thread.exc
1136
- decoded_output = outputs
1137
- else:
1138
- outputs = model.generate(**gen_kwargs)
1139
- outputs = [decoder(s) for s in outputs.sequences]
1140
- yield dict(response=prompter.get_response(outputs, prompt=inputs_decoded,
1141
- sanitize_bot_response=sanitize_bot_response), sources='')
1142
- if outputs and len(outputs) >= 1:
1143
- decoded_output = prompt + outputs[0]
1144
- if save_dir and decoded_output:
1145
- save_generate_output(output=decoded_output, base_model=base_model, save_dir=save_dir)
1146
- if verbose:
1147
- print('Post-Generate: %s decoded_output: %s' % (
1148
- str(datetime.now()), len(decoded_output) if decoded_output else -1), flush=True)
1149
-
1150
-
1151
- inputs_list_names = list(inspect.signature(evaluate).parameters)
1152
- state_names = ['model_state', 'my_db_state']
1153
- inputs_kwargs_list = [x for x in inputs_list_names if x not in eval_func_param_names + state_names]
1154
-
1155
-
1156
- def get_cutoffs(memory_restriction_level, for_context=False, model_max_length=2048):
1157
- # help to avoid errors like:
1158
- # RuntimeError: The size of tensor a (2048) must match the size of tensor b (2049) at non-singleton dimension 3
1159
- # RuntimeError: expected scalar type Half but found Float
1160
- # with - 256
1161
- if memory_restriction_level > 0:
1162
- max_length_tokenize = 768 - 256 if memory_restriction_level <= 2 else 512 - 256
1163
- else:
1164
- max_length_tokenize = model_max_length - 256
1165
- cutoff_len = max_length_tokenize * 4 # if reaches limit, then can't generate new tokens
1166
- output_smallest = 30 * 4
1167
- max_prompt_length = cutoff_len - output_smallest
1168
-
1169
- if for_context:
1170
- # then lower even more to avoid later chop, since just estimate tokens in context bot
1171
- max_prompt_length = max(64, int(max_prompt_length * 0.8))
1172
-
1173
- return cutoff_len, output_smallest, max_length_tokenize, max_prompt_length
1174
-
1175
-
1176
- class H2OTextIteratorStreamer(TextIteratorStreamer):
1177
- """
1178
- normally, a timeout is required to handle exceptions raised from get();
179
- with the H2O version of TextIteratorStreamer, we instead loop over block to handle them
1180
- """
1181
-
1182
- def __init__(self, tokenizer, skip_prompt: bool = False, timeout: typing.Optional[float] = None,
1183
- block=True, **decode_kwargs):
1184
- super().__init__(tokenizer, skip_prompt, **decode_kwargs)
1185
- self.text_queue = queue.Queue()
1186
- self.stop_signal = None
1187
- self.do_stop = False
1188
- self.timeout = timeout
1189
- self.block = block
1190
-
1191
- def on_finalized_text(self, text: str, stream_end: bool = False):
1192
- """Put the new text in the queue. If the stream is ending, also put a stop signal in the queue."""
1193
- self.text_queue.put(text, timeout=self.timeout)
1194
- if stream_end:
1195
- self.text_queue.put(self.stop_signal, timeout=self.timeout)
1196
-
1197
- def __iter__(self):
1198
- return self
1199
-
1200
- def __next__(self):
1201
- while True:
1202
- try:
1203
- value = self.stop_signal # value looks unused in pycharm, not true
1204
- if self.do_stop:
1205
- print("hit stop", flush=True)
1206
- # could raise or break, maybe best to raise and make parent see if any exception in thread
1207
- raise StopIteration()
1208
- # break
1209
- value = self.text_queue.get(block=self.block, timeout=self.timeout)
1210
- break
1211
- except queue.Empty:
1212
- time.sleep(0.01)
1213
- if value == self.stop_signal:
1214
- raise StopIteration()
1215
- else:
1216
- return value
1217
-
1218
-
1219
- def generate_with_exceptions(func, *args, prompt='', inputs_decoded='', raise_generate_gpu_exceptions=True, **kwargs):
1220
- try:
1221
- func(*args, **kwargs)
1222
- except torch.cuda.OutOfMemoryError as e:
1223
- print("GPU OOM 2: prompt: %s inputs_decoded: %s exception: %s" % (prompt, inputs_decoded, str(e)),
1224
- flush=True)
1225
- if 'input_ids' in kwargs:
1226
- if kwargs['input_ids'] is not None:
1227
- kwargs['input_ids'].cpu()
1228
- kwargs['input_ids'] = None
1229
- traceback.print_exc()
1230
- clear_torch_cache()
1231
- return
1232
- except (Exception, RuntimeError) as e:
1233
- if 'Expected all tensors to be on the same device' in str(e) or \
1234
- 'expected scalar type Half but found Float' in str(e) or \
1235
- 'probability tensor contains either' in str(e) or \
1236
- 'cublasLt ran into an error!' in str(e) or \
1237
- 'mat1 and mat2 shapes cannot be multiplied' in str(e):
1238
- print(
1239
- "GPU Error: prompt: %s inputs_decoded: %s exception: %s" % (prompt, inputs_decoded, str(e)),
1240
- flush=True)
1241
- traceback.print_exc()
1242
- clear_torch_cache()
1243
- if raise_generate_gpu_exceptions:
1244
- raise
1245
- return
1246
- else:
1247
- clear_torch_cache()
1248
- if raise_generate_gpu_exceptions:
1249
- raise
1250
-
1251
-
1252
- def get_generate_params(model_lower, chat,
1253
- stream_output, show_examples,
1254
- prompt_type, temperature, top_p, top_k, num_beams,
1255
- max_new_tokens, min_new_tokens, early_stopping, max_time,
1256
- repetition_penalty, num_return_sequences,
1257
- do_sample, k, verbose):
1258
- use_defaults = False
1259
- use_default_examples = True
1260
- examples = []
1261
- task_info = f"{prompt_type}"
1262
- if model_lower:
1263
- print(f"Using Model {model_lower}", flush=True)
1264
- else:
1265
- print("No model defined yet", flush=True)
1266
-
1267
- min_new_tokens = min_new_tokens if min_new_tokens is not None else 0
1268
- early_stopping = early_stopping if early_stopping is not None else False
1269
- max_time_defaults = 60 * 3
1270
- max_time = max_time if max_time is not None else max_time_defaults
1271
-
1272
- if not prompt_type and model_lower in inv_prompt_type_to_model_lower:
1273
- prompt_type = inv_prompt_type_to_model_lower[model_lower]
1274
- if verbose:
1275
- print("Auto-selecting prompt_type=%s for %s" % (prompt_type, model_lower), flush=True)
1276
-
1277
- # examples at first don't include chat, instruction_nochat, iinput_nochat, added at end
1278
- if show_examples is None:
1279
- if chat:
1280
- show_examples = False
1281
- else:
1282
- show_examples = True
1283
-
1284
- summarize_example1 = """Jeff: Can I train a ? Transformers model on Amazon SageMaker?
1285
- Philipp: Sure you can use the new Hugging Face Deep Learning Container.
1286
- Jeff: ok.
1287
- Jeff: and how can I get started?
1288
- Jeff: where can I find documentation?
1289
- Philipp: ok, ok you can find everything here. https://huggingface.co/blog/the-partnership-amazon-sagemaker-and-hugging-face"""
1290
-
1291
- if 'bart-large-cnn-samsum' in model_lower or 'flan-t5-base-samsum' in model_lower:
1292
- placeholder_instruction = summarize_example1
1293
- placeholder_input = ""
1294
- use_defaults = True
1295
- use_default_examples = False
1296
- examples += [
1297
- [placeholder_instruction, "", "", stream_output, 'plain', 1.0, 1.0, 50, 1, 128, 0, False, max_time_defaults,
1298
- 1.0, 1,
1299
- False]]
1300
- task_info = "Summarization"
1301
- elif 't5-' in model_lower or 't5' == model_lower or 'flan-' in model_lower:
1302
- placeholder_instruction = "The square root of x is the cube root of y. What is y to the power of 2, if x = 4?"
1303
- placeholder_input = ""
1304
- use_defaults = True
1305
- use_default_examples = True
1306
- task_info = "Multi-Task: Q/A, translation, Chain-of-Thought, Logical Reasoning, Summarization, etc. Best to use task prefix as trained on, e.g. `translate English to German: ` (space after colon)"
1307
- elif 'mbart-' in model_lower:
1308
- placeholder_instruction = "The girl has long hair."
1309
- placeholder_input = ""
1310
- use_defaults = True
1311
- use_default_examples = False
1312
- examples += [
1313
- [placeholder_instruction, "", "", stream_output, 'plain', 1.0, 1.0, 50, 1, 128, 0, False, max_time_defaults,
1314
- 1.0, 1,
1315
- False]]
1316
- elif 'gpt2' in model_lower:
1317
- placeholder_instruction = "The sky is"
1318
- placeholder_input = ""
1319
- prompt_type = prompt_type or 'plain'
1320
- use_default_examples = True # some will be odd "continuations" but can be ok
1321
- examples += [
1322
- [placeholder_instruction, "", "", stream_output, 'plain', 1.0, 1.0, 50, 1, 128, 0, False, max_time_defaults,
1323
- 1.0, 1,
1324
- False]]
1325
- task_info = "Auto-complete phrase, code, etc."
1326
- use_defaults = True
1327
- else:
1328
- if chat:
1329
- placeholder_instruction = "Enter a question or imperative."
1330
- else:
1331
- placeholder_instruction = "Give detailed answer for whether Einstein or Newton is smarter."
1332
- placeholder_input = ""
1333
- if model_lower:
1334
- # default is plain, because might rely upon trust_remote_code to handle prompting
1335
- prompt_type = prompt_type or 'plain'
1336
- else:
1337
- prompt_type = ''
1338
- task_info = "No task"
1339
- if prompt_type == 'instruct':
1340
- task_info = "Answer question or follow imperative as instruction with optionally input."
1341
- elif prompt_type == 'plain':
1342
- task_info = "Auto-complete phrase, code, etc."
1343
- elif prompt_type == 'human_bot':
1344
- if chat:
1345
- task_info = "Chat (Shift-Enter to give question/imperative, input concatenated with instruction)"
1346
- else:
1347
- task_info = "Ask question/imperative (input concatenated with instruction)"
1348
-
1349
- # revert to plain if still nothing
1350
- prompt_type = prompt_type or 'plain'
1351
- if use_defaults:
1352
- temperature = 1.0 if temperature is None else temperature
1353
- top_p = 1.0 if top_p is None else top_p
1354
- top_k = 40 if top_k is None else top_k
1355
- num_beams = num_beams or 1
1356
- max_new_tokens = max_new_tokens or 128
1357
- repetition_penalty = repetition_penalty or 1.07
1358
- num_return_sequences = min(num_beams, num_return_sequences or 1)
1359
- do_sample = False if do_sample is None else do_sample
1360
- else:
1361
- temperature = 0.1 if temperature is None else temperature
1362
- top_p = 0.75 if top_p is None else top_p
1363
- top_k = 40 if top_k is None else top_k
1364
- if chat:
1365
- num_beams = num_beams or 1
1366
- else:
1367
- num_beams = num_beams or 4
1368
- max_new_tokens = max_new_tokens or 256
1369
- repetition_penalty = repetition_penalty or 1.07
1370
- num_return_sequences = min(num_beams, num_return_sequences or 1)
1371
- do_sample = False if do_sample is None else do_sample
1372
- # doesn't include chat, instruction_nochat, iinput_nochat, added later
1373
- params_list = ["", stream_output, prompt_type, temperature, top_p, top_k, num_beams, max_new_tokens, min_new_tokens,
1374
- early_stopping, max_time, repetition_penalty, num_return_sequences, do_sample]
1375
-
1376
- if use_default_examples:
1377
- examples += [
1378
- ["Translate English to French", "Good morning"] + params_list,
1379
- ["Give detailed answer for whether Einstein or Newton is smarter.", ''] + params_list,
1380
- ["Explain in detailed list, all the best practices for coding in python.", ''] + params_list,
1381
- [
1382
- "Create a markdown table with 3 rows for the primary colors, and 2 columns, with color name and hex codes.",
1383
- ''] + params_list,
1384
- ['Translate to German: My name is Arthur', ''] + params_list,
1385
- ["Please answer to the following question. Who is going to be the next Ballon d'or?", ''] + params_list,
1386
- ['Can Geoffrey Hinton have a conversation with George Washington? Give the rationale before answering.',
1387
- ''] + params_list,
1388
- ['Please answer the following question. What is the boiling point of Nitrogen?', ''] + params_list,
1389
- ['Answer the following yes/no question. Can you write a whole Haiku in a single tweet?', ''] + params_list,
1390
- ["Simplify the following expression: (False or False and True). Explain your answer.", ''] + params_list,
1391
- [
1392
- "Premise: At my age you will probably have learnt one lesson. Hypothesis: It's not certain how many lessons you'll learn by your thirties. Does the premise entail the hypothesis?",
1393
- ''] + params_list,
1394
- ['The square root of x is the cube root of y. What is y to the power of 2, if x = 4?', ''] + params_list,
1395
- [
1396
- 'Answer the following question by reasoning step by step. The cafeteria had 23 apples. If they used 20 for lunch, and bought 6 more, how many apple do they have?',
1397
- ''] + params_list,
1398
- ["""def area_of_rectangle(a: float, b: float):
1399
- \"\"\"Return the area of the rectangle.\"\"\"""", ''] + params_list,
1400
- ["""# a function in native python:
1401
- def mean(a):
1402
- return sum(a)/len(a)
1403
-
1404
- # the same function using numpy:
1405
- import numpy as np
1406
- def mean(a):""", ''] + params_list,
1407
- ["""X = np.random.randn(100, 100)
1408
- y = np.random.randint(0, 1, 100)
1409
-
1410
- # fit random forest classifier with 20 estimators""", ''] + params_list,
1411
- ]
1412
- # add summary example
1413
- examples += [[summarize_example1, 'Summarize' if prompt_type not in ['plain', 'instruct_simple'] else ''] + params_list]
1414
-
1415
- src_lang = "English"
1416
- tgt_lang = "Russian"
1417
-
1418
- # move to correct position
1419
- for example in examples:
1420
- example += [chat, '', '', 'Disabled', k, ['All']]
1421
- # adjust examples if non-chat mode
1422
- if not chat:
1423
- example[eval_func_param_names.index('instruction_nochat')] = example[
1424
- eval_func_param_names.index('instruction')]
1425
- example[eval_func_param_names.index('instruction')] = ''
1426
-
1427
- example[eval_func_param_names.index('iinput_nochat')] = example[eval_func_param_names.index('iinput')]
1428
- example[eval_func_param_names.index('iinput')] = ''
1429
- assert len(example) == len(eval_func_param_names), "Wrong example: %s %s" % (
1430
- len(example), len(eval_func_param_names))
1431
-
1432
- return placeholder_instruction, placeholder_input, \
1433
- stream_output, show_examples, \
1434
- prompt_type, temperature, top_p, top_k, num_beams, \
1435
- max_new_tokens, min_new_tokens, early_stopping, max_time, \
1436
- repetition_penalty, num_return_sequences, \
1437
- do_sample, \
1438
- src_lang, tgt_lang, \
1439
- examples, \
1440
- task_info
1441
-
1442
-
1443
- def languages_covered():
1444
- # https://huggingface.co/facebook/mbart-large-50-many-to-many-mmt#languages-covered
1445
- covered = """Arabic (ar_AR), Czech (cs_CZ), German (de_DE), English (en_XX), Spanish (es_XX), Estonian (et_EE), Finnish (fi_FI), French (fr_XX), Gujarati (gu_IN), Hindi (hi_IN), Italian (it_IT), Japanese (ja_XX), Kazakh (kk_KZ), Korean (ko_KR), Lithuanian (lt_LT), Latvian (lv_LV), Burmese (my_MM), Nepali (ne_NP), Dutch (nl_XX), Romanian (ro_RO), Russian (ru_RU), Sinhala (si_LK), Turkish (tr_TR), Vietnamese (vi_VN), Chinese (zh_CN), Afrikaans (af_ZA), Azerbaijani (az_AZ), Bengali (bn_IN), Persian (fa_IR), Hebrew (he_IL), Croatian (hr_HR), Indonesian (id_ID), Georgian (ka_GE), Khmer (km_KH), Macedonian (mk_MK), Malayalam (ml_IN), Mongolian (mn_MN), Marathi (mr_IN), Polish (pl_PL), Pashto (ps_AF), Portuguese (pt_XX), Swedish (sv_SE), Swahili (sw_KE), Tamil (ta_IN), Telugu (te_IN), Thai (th_TH), Tagalog (tl_XX), Ukrainian (uk_UA), Urdu (ur_PK), Xhosa (xh_ZA), Galician (gl_ES), Slovene (sl_SI)"""
1446
- covered = covered.split(', ')
1447
- covered = {x.split(' ')[0]: x.split(' ')[1].replace(')', '').replace('(', '') for x in covered}
1448
- return covered
1449
-
1450
-
1451
- def get_context(chat_context, prompt_type):
1452
- if chat_context and prompt_type == 'human_bot':
1453
- context0 = """<bot>: I am an intelligent, helpful, truthful, and fair assistant named h2oGPT, who will give accurate, balanced, and reliable responses. I will not respond with I don't know or I don't understand.
1454
- <human>: I am a human person seeking useful assistance and request all questions be answered completely, and typically expect detailed responses. Give answers in numbered list format if several distinct but related items are being listed."""
1455
- else:
1456
- context0 = ''
1457
- return context0
1458
-
1459
-
1460
- def score_qa(smodel, stokenizer, max_length_tokenize, question, answer, cutoff_len):
1461
- question = question[-cutoff_len:]
1462
- answer = answer[-cutoff_len:]
1463
-
1464
- inputs = stokenizer(question, answer,
1465
- return_tensors="pt",
1466
- truncation=True,
1467
- max_length=max_length_tokenize).to(smodel.device)
1468
- try:
1469
- score = torch.sigmoid(smodel(**inputs).logits[0]).cpu().detach().numpy()[0]
1470
- except torch.cuda.OutOfMemoryError as e:
1471
- print("GPU OOM 3: question: %s answer: %s exception: %s" % (question, answer, str(e)), flush=True)
1472
- del inputs
1473
- traceback.print_exc()
1474
- clear_torch_cache()
1475
- return 'Response Score: GPU OOM'
1476
- except (Exception, RuntimeError) as e:
1477
- if 'Expected all tensors to be on the same device' in str(e) or \
1478
- 'expected scalar type Half but found Float' in str(e) or \
1479
- 'probability tensor contains either' in str(e) or \
1480
- 'cublasLt ran into an error!' in str(e):
1481
- print("GPU Error: question: %s answer: %s exception: %s" % (question, answer, str(e)),
1482
- flush=True)
1483
- traceback.print_exc()
1484
- clear_torch_cache()
1485
- return 'Response Score: GPU Error'
1486
- else:
1487
- raise
1488
- os.environ['TOKENIZERS_PARALLELISM'] = 'true'
1489
- return score
1490
-
1491
-
1492
- def check_locals(**kwargs):
1493
- # ensure everything in evaluate is here
1494
- can_skip_because_locally_generated = [ # evaluate
1495
- 'instruction',
1496
- 'iinput',
1497
- 'context',
1498
- 'instruction_nochat',
1499
- 'iinput_nochat',
1500
- # get_model:
1501
- 'reward_type'
1502
- ]
1503
- for k in eval_func_param_names:
1504
- if k in can_skip_because_locally_generated:
1505
- continue
1506
- assert k in kwargs, "Missing %s" % k
1507
- for k in inputs_kwargs_list:
1508
- if k in can_skip_because_locally_generated:
1509
- continue
1510
- assert k in kwargs, "Missing %s" % k
1511
-
1512
- for k in list(inspect.signature(get_model).parameters):
1513
- if k in can_skip_because_locally_generated:
1514
- continue
1515
- assert k in kwargs, "Missing %s" % k
1516
-
1517
-
1518
- if __name__ == "__main__":
1519
- """
1520
- Examples:
1521
-
1522
- WORLD_SIZE=4 CUDA_VISIBLE_DEVICES="0,1,2,3" torchrun --nproc_per_node=4 --master_port=1234 generate.py --base_model='EleutherAI/gpt-j-6B' --lora_weights=lora-alpaca_6B
1523
- python generate.py --base_model='EleutherAI/gpt-j-6B' --lora_weights='lora-alpaca_6B'
1524
- python generate.py --base_model='EleutherAI/gpt-neox-20b' --lora_weights='lora-alpaca_20B'
1525
-
1526
- # generate without lora weights, no prompt
1527
- python generate.py --base_model='EleutherAI/gpt-neox-20b' --prompt_type='plain'
1528
- python generate.py --base_model='togethercomputer/GPT-NeoXT-Chat-Base-20B' --prompt_type='dai_faq'
1529
-
1530
- python generate.py --base_model='togethercomputer/GPT-NeoXT-Chat-Base-20B' --prompt_type='dai_faq' --lora_weights='lora_20B_daifaq'
1531
- # OpenChatKit settings:
1532
- python generate.py --base_model='togethercomputer/GPT-NeoXT-Chat-Base-20B' --prompt_type='human_bot' --debug=True --num_beams=1 --temperature=0.6 --top_k=40 --top_p=1.0
1533
-
1534
- python generate.py --base_model='distilgpt2' --prompt_type='plain' --debug=True --num_beams=1 --temperature=0.6 --top_k=40 --top_p=1.0 --share=False
1535
- python generate.py --base_model='t5-large' --prompt_type='simple_instruct'
1536
- python generate.py --base_model='philschmid/bart-large-cnn-samsum'
1537
- python generate.py --base_model='philschmid/flan-t5-base-samsum'
1538
- python generate.py --base_model='facebook/mbart-large-50-many-to-many-mmt'
1539
-
1540
- python generate.py --base_model='togethercomputer/GPT-NeoXT-Chat-Base-20B' --prompt_type='human_bot' --lora_weights='GPT-NeoXT-Chat-Base-20B.merged.json.8_epochs.57b2892c53df5b8cefac45f84d019cace803ef26.28'
1541
-
1542
- must have 4*48GB GPU and run without 8bit in order for sharding to work with infer_devices=False
1543
- can also pass --prompt_type='human_bot' and model can somewhat handle instructions without being instruct tuned
1544
- python generate.py --base_model=decapoda-research/llama-65b-hf --load_8bit=False --infer_devices=False --prompt_type='human_bot'
1545
-
1546
- python generate.py --base_model=h2oai/h2ogpt-oig-oasst1-512-6_9b
1547
- """
1548
- fire.Fire(main)
generate.py ADDED
@@ -0,0 +1 @@
 
 
1
+ ../../generate.py
gpt4all_llm.py DELETED
@@ -1,255 +0,0 @@
1
- import inspect
2
- import os
3
- import sys
4
- from typing import Dict, Any, Optional, List
5
- from langchain.callbacks.manager import CallbackManagerForLLMRun
6
- from pydantic import root_validator
7
- from langchain.llms import gpt4all
8
- from dotenv import dotenv_values
9
-
10
-
11
- class FakeTokenizer:
12
-
13
- def encode(self, x, *args, **kwargs):
14
- return dict(input_ids=[x])
15
-
16
- def decode(self, x, *args, **kwargs):
17
- return x
18
-
19
- def __call__(self, x, *args, **kwargs):
20
- return self.encode(x, *args, **kwargs)
21
-
22
-
23
- def get_model_tokenizer_gpt4all(base_model, **kwargs):
24
- # defaults (some of these are generation parameters, so need to be passed in at generation time)
25
- model_kwargs = dict(n_threads=os.cpu_count() // 2,
26
- temp=kwargs.get('temperature', 0.2),
27
- top_p=kwargs.get('top_p', 0.75),
28
- top_k=kwargs.get('top_k', 40),
29
- n_ctx=2048 - 256)
30
- env_gpt4all_file = ".env_gpt4all"
31
- model_kwargs.update(dotenv_values(env_gpt4all_file))
32
-
33
- if base_model == "llama":
34
- if 'model_path_llama' not in model_kwargs:
35
- raise ValueError("No model_path_llama in %s" % env_gpt4all_file)
36
- model_path = model_kwargs.pop('model_path_llama')
37
- # FIXME: GPT4All version of llama doesn't handle new quantization, so use llama_cpp_python
38
- from llama_cpp import Llama
39
- # llama sets some things at init model time, not generation time
40
- func_names = list(inspect.signature(Llama.__init__).parameters)
41
- model_kwargs = {k: v for k, v in model_kwargs.items() if k in func_names}
42
- model_kwargs['n_ctx'] = int(model_kwargs['n_ctx'])
43
- model = Llama(model_path=model_path, **model_kwargs)
44
- elif base_model in "gpt4all_llama":
45
- if 'model_name_gpt4all_llama' not in model_kwargs and 'model_path_gpt4all_llama' not in model_kwargs:
46
- raise ValueError("No model_name_gpt4all_llama or model_path_gpt4all_llama in %s" % env_gpt4all_file)
47
- model_name = model_kwargs.pop('model_name_gpt4all_llama')
48
- model_type = 'llama'
49
- from gpt4all import GPT4All as GPT4AllModel
50
- model = GPT4AllModel(model_name=model_name, model_type=model_type)
51
- elif base_model in "gptj":
52
- if 'model_name_gptj' not in model_kwargs and 'model_path_gptj' not in model_kwargs:
53
- raise ValueError("No model_name_gpt4j or model_path_gpt4j in %s" % env_gpt4all_file)
54
- model_name = model_kwargs.pop('model_name_gptj')
55
- model_type = 'gptj'
56
- from gpt4all import GPT4All as GPT4AllModel
57
- model = GPT4AllModel(model_name=model_name, model_type=model_type)
58
- else:
59
- raise ValueError("No such base_model %s" % base_model)
60
- return model, FakeTokenizer(), 'cpu'
61
-
62
-
63
- from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
64
-
65
-
66
- class H2OStreamingStdOutCallbackHandler(StreamingStdOutCallbackHandler):
67
-
68
- def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
69
- """Run on new LLM token. Only available when streaming is enabled."""
70
- # streaming to std already occurs without this
71
- # sys.stdout.write(token)
72
- # sys.stdout.flush()
73
- pass
74
-
75
-
76
- def get_model_kwargs(env_kwargs, default_kwargs, cls):
77
- # default from class
78
- model_kwargs = {k: v.default for k, v in dict(inspect.signature(cls).parameters).items()}
79
- # from our defaults
80
- model_kwargs.update(default_kwargs)
81
- # from user defaults
82
- model_kwargs.update(env_kwargs)
83
- # ensure only valid keys
84
- func_names = list(inspect.signature(cls).parameters)
85
- model_kwargs = {k: v for k, v in model_kwargs.items() if k in func_names}
86
- return model_kwargs
87
-
88
-
89
- def get_llm_gpt4all(model_name,
90
- model=None,
91
- max_new_tokens=256,
92
- temperature=0.1,
93
- repetition_penalty=1.0,
94
- top_k=40,
95
- top_p=0.7,
96
- verbose=False):
97
- env_gpt4all_file = ".env_gpt4all"
98
- env_kwargs = dotenv_values(env_gpt4all_file)
99
- callbacks = [H2OStreamingStdOutCallbackHandler()]
100
- n_ctx = env_kwargs.pop('n_ctx', 2048 - max_new_tokens)
101
- default_kwargs = dict(context_erase=0.5,
102
- n_batch=1,
103
- n_ctx=n_ctx,
104
- n_predict=max_new_tokens,
105
- repeat_last_n=64 if repetition_penalty != 1.0 else 0,
106
- repeat_penalty=repetition_penalty,
107
- temp=temperature,
108
- temperature=temperature,
109
- top_k=top_k,
110
- top_p=top_p,
111
- use_mlock=True,
112
- verbose=verbose)
113
- if model_name == 'llama':
114
- cls = H2OLlamaCpp
115
- model_path = env_kwargs.pop('model_path_llama') if model is None else model
116
- model_kwargs = get_model_kwargs(env_kwargs, default_kwargs, cls)
117
- model_kwargs.update(dict(model_path=model_path, callbacks=callbacks))
118
- llm = cls(**model_kwargs)
119
- llm.client.verbose = verbose
120
- elif model_name == 'gpt4all_llama':
121
- cls = H2OGPT4All
122
- model_path = env_kwargs.pop('model_path_gpt4all_llama') if model is None else model
123
- model_kwargs = get_model_kwargs(env_kwargs, default_kwargs, cls)
124
- model_kwargs.update(dict(model=model_path, backend='llama', callbacks=callbacks))
125
- llm = cls(**model_kwargs)
126
- elif model_name == 'gptj':
127
- cls = H2OGPT4All
128
- model_path = env_kwargs.pop('model_path_gptj') if model is None else model
129
- model_kwargs = get_model_kwargs(env_kwargs, default_kwargs, cls)
130
- model_kwargs.update(dict(model=model_path, backend='gptj', callbacks=callbacks))
131
- llm = cls(**model_kwargs)
132
- else:
133
- raise RuntimeError("No such model_name %s" % model_name)
134
- return llm
135
-
136
-
137
- class H2OGPT4All(gpt4all.GPT4All):
138
- model: Any
139
- """Path to the pre-trained GPT4All model file."""
140
-
141
- @root_validator()
142
- def validate_environment(cls, values: Dict) -> Dict:
143
- """Validate that the python package exists in the environment."""
144
- try:
145
- if isinstance(values["model"], str):
146
- from gpt4all import GPT4All as GPT4AllModel
147
-
148
- full_path = values["model"]
149
- model_path, delimiter, model_name = full_path.rpartition("/")
150
- model_path += delimiter
151
-
152
- values["client"] = GPT4AllModel(
153
- model_name=model_name,
154
- model_path=model_path or None,
155
- model_type=values["backend"],
156
- allow_download=False,
157
- )
158
- else:
159
- values["client"] = values["model"]
160
- values["backend"] = values["client"].model.model_type
161
-
162
- except ImportError:
163
- raise ValueError(
164
- "Could not import gpt4all python package. "
165
- "Please install it with `pip install gpt4all`."
166
- )
167
- return values
168
-
169
- def _call(
170
- self,
171
- prompt: str,
172
- stop: Optional[List[str]] = None,
173
- run_manager: Optional[CallbackManagerForLLMRun] = None,
174
- ) -> str:
175
- # Roughly 4 chars per token if natural language
176
- prompt = prompt[-self.n_ctx * 4:]
177
- verbose = False
178
- if verbose:
179
- print("_call prompt: %s" % prompt, flush=True)
180
- return super()._call(prompt, stop=stop, run_manager=run_manager)
181
-
182
-
183
- from langchain.llms import LlamaCpp
184
-
185
-
186
- class H2OLlamaCpp(LlamaCpp):
187
- model_path: Any
188
- """Path to the pre-trained GPT4All model file."""
189
-
190
- @root_validator()
191
- def validate_environment(cls, values: Dict) -> Dict:
192
- """Validate that llama-cpp-python library is installed."""
193
- if isinstance(values["model_path"], str):
194
- model_path = values["model_path"]
195
- model_param_names = [
196
- "lora_path",
197
- "lora_base",
198
- "n_ctx",
199
- "n_parts",
200
- "seed",
201
- "f16_kv",
202
- "logits_all",
203
- "vocab_only",
204
- "use_mlock",
205
- "n_threads",
206
- "n_batch",
207
- "use_mmap",
208
- "last_n_tokens_size",
209
- ]
210
- model_params = {k: values[k] for k in model_param_names}
211
- # For backwards compatibility, only include if non-null.
212
- if values["n_gpu_layers"] is not None:
213
- model_params["n_gpu_layers"] = values["n_gpu_layers"]
214
-
215
- try:
216
- from llama_cpp import Llama
217
-
218
- values["client"] = Llama(model_path, **model_params)
219
- except ImportError:
220
- raise ModuleNotFoundError(
221
- "Could not import llama-cpp-python library. "
222
- "Please install the llama-cpp-python library to "
223
- "use this embedding model: pip install llama-cpp-python"
224
- )
225
- except Exception as e:
226
- raise ValueError(
227
- f"Could not load Llama model from path: {model_path}. "
228
- f"Received error {e}"
229
- )
230
- else:
231
- values["client"] = values["model_path"]
232
- return values
233
-
234
- def _call(
235
- self,
236
- prompt: str,
237
- stop: Optional[List[str]] = None,
238
- run_manager: Optional[CallbackManagerForLLMRun] = None,
239
- ) -> str:
240
- verbose = False
241
- # tokenize twice, just to count tokens, since llama cpp python wrapper has no way to truncate
242
- prompt_tokens = self.client.tokenize(b" " + prompt.encode("utf-8"))
243
- num_prompt_tokens = len(prompt_tokens)
244
- if num_prompt_tokens > self.n_ctx:
245
- # conservative by using int()
246
- chars_per_token = int(len(prompt) / num_prompt_tokens)
247
- prompt = prompt[-self.n_ctx * chars_per_token:]
248
- if verbose:
249
- print("reducing tokens, assuming average of %s chars/token: %s" % chars_per_token, flush=True)
250
- prompt_tokens2 = self.client.tokenize(b" " + prompt.encode("utf-8"))
251
- num_prompt_tokens2 = len(prompt_tokens2)
252
- print("reduced tokens from %d -> %d" % (num_prompt_tokens, num_prompt_tokens2), flush=True)
253
- if verbose:
254
- print("_call prompt: %s" % prompt, flush=True)
255
- return super()._call(prompt, stop=stop, run_manager=run_manager)
gpt4all_llm.py ADDED
@@ -0,0 +1 @@
 
 
1
+ ../../gpt4all_llm.py
gpt_langchain.py DELETED
@@ -1,1471 +0,0 @@
1
- import glob
2
- import inspect
3
- import os
4
- import pathlib
5
- import pickle
6
- import queue
7
- import shutil
8
- import subprocess
9
- import sys
10
- import tempfile
11
- import traceback
12
- import uuid
13
- import zipfile
14
- from collections import defaultdict
15
- from datetime import datetime
16
- from functools import reduce
17
- from operator import concat
18
-
19
- from joblib import Parallel, delayed
20
- from tqdm import tqdm
21
-
22
- from prompter import non_hf_types
23
- from utils import wrapped_partial, EThread, import_matplotlib, sanitize_filename, makedirs, get_url, flatten_list, \
24
- get_device, ProgressParallel, remove, hash_file
25
-
26
- import_matplotlib()
27
-
28
- import numpy as np
29
- import pandas as pd
30
- import requests
31
- from langchain.chains.qa_with_sources import load_qa_with_sources_chain
32
- # , GCSDirectoryLoader, GCSFileLoader
33
- # , OutlookMessageLoader # GPL3
34
- # ImageCaptionLoader, # use our own wrapper
35
- # ReadTheDocsLoader, # no special file, some path, so have to give as special option
36
- from langchain.document_loaders import PyPDFLoader, TextLoader, CSVLoader, PythonLoader, TomlLoader, \
37
- UnstructuredURLLoader, UnstructuredHTMLLoader, UnstructuredWordDocumentLoader, UnstructuredMarkdownLoader, \
38
- EverNoteLoader, UnstructuredEmailLoader, UnstructuredODTLoader, UnstructuredPowerPointLoader, \
39
- UnstructuredEPubLoader, UnstructuredImageLoader, UnstructuredRTFLoader, ArxivLoader
40
- from langchain.text_splitter import RecursiveCharacterTextSplitter
41
- from langchain.chains.question_answering import load_qa_chain
42
- from langchain.docstore.document import Document
43
- from langchain import PromptTemplate
44
- from langchain.vectorstores import Chroma
45
-
46
-
47
- def get_db(sources, use_openai_embedding=False, db_type='faiss', persist_directory="db_dir", langchain_mode='notset',
48
- collection_name=None,
49
- hf_embedding_model="sentence-transformers/all-MiniLM-L6-v2"):
50
- if not sources:
51
- return None
52
- # get embedding model
53
- embedding = get_embedding(use_openai_embedding, hf_embedding_model=hf_embedding_model)
54
- assert collection_name is not None or langchain_mode != 'notset'
55
- if collection_name is None:
56
- collection_name = langchain_mode.replace(' ', '_')
57
-
58
- # Create vector database
59
- if db_type == 'faiss':
60
- from langchain.vectorstores import FAISS
61
- db = FAISS.from_documents(sources, embedding)
62
-
63
- elif db_type == 'weaviate':
64
- import weaviate
65
- from weaviate.embedded import EmbeddedOptions
66
- from langchain.vectorstores import Weaviate
67
-
68
- # TODO: add support for connecting via docker compose
69
- client = weaviate.Client(
70
- embedded_options=EmbeddedOptions()
71
- )
72
- index_name = collection_name.capitalize()
73
- db = Weaviate.from_documents(documents=sources, embedding=embedding, client=client, by_text=False,
74
- index_name=index_name)
75
-
76
- elif db_type == 'chroma':
77
- assert persist_directory is not None
78
- os.makedirs(persist_directory, exist_ok=True)
79
- db = Chroma.from_documents(documents=sources,
80
- embedding=embedding,
81
- persist_directory=persist_directory,
82
- collection_name=collection_name,
83
- anonymized_telemetry=False)
84
- db.persist()
85
- else:
86
- raise RuntimeError("No such db_type=%s" % db_type)
87
-
88
- return db
89
-
90
-
91
- def _get_unique_sources_in_weaviate(db):
92
- batch_size = 100
93
- id_source_list = []
94
- result = db._client.data_object.get(class_name=db._index_name, limit=batch_size)
95
-
96
- while result['objects']:
97
- id_source_list += [(obj['id'], obj['properties']['source']) for obj in result['objects']]
98
- last_id = id_source_list[-1][0]
99
- result = db._client.data_object.get(class_name=db._index_name, limit=batch_size, after=last_id)
100
-
101
- unique_sources = {source for _, source in id_source_list}
102
- return unique_sources
103
-
104
-
105
- def add_to_db(db, sources, db_type='faiss',
106
- avoid_dup_by_file=False,
107
- avoid_dup_by_content=True):
108
- num_new_sources = len(sources)
109
- if not sources:
110
- return db, num_new_sources, []
111
- if db_type == 'faiss':
112
- db.add_documents(sources)
113
- elif db_type == 'weaviate':
114
- # FIXME: only control by file name, not hash yet
115
- if avoid_dup_by_file or avoid_dup_by_content:
116
- unique_sources = _get_unique_sources_in_weaviate(db)
117
- sources = [x for x in sources if x.metadata['source'] not in unique_sources]
118
- num_new_sources = len(sources)
119
- if num_new_sources == 0:
120
- return db, num_new_sources, []
121
- db.add_documents(documents=sources)
122
- elif db_type == 'chroma':
123
- collection = db.get()
124
- # files we already have:
125
- metadata_files = set([x['source'] for x in collection['metadatas']])
126
- if avoid_dup_by_file:
127
- # Too weak in case file changed content, assume parent shouldn't pass true for this for now
128
- raise RuntimeError("Not desired code path")
129
- sources = [x for x in sources if x.metadata['source'] not in metadata_files]
130
- if avoid_dup_by_content:
131
- # look at hash, instead of page_content
132
- # migration: If no hash previously, avoid updating,
133
- # since don't know if need to update and may be expensive to redo all unhashed files
134
- metadata_hash_ids = set(
135
- [x['hashid'] for x in collection['metadatas'] if 'hashid' in x and x['hashid'] not in ["None", None]])
136
- # avoid sources with same hash
137
- sources = [x for x in sources if x.metadata.get('hashid') not in metadata_hash_ids]
138
- # get new file names that match existing file names. delete existing files we are overridding
139
- dup_metadata_files = set([x.metadata['source'] for x in sources if x.metadata['source'] in metadata_files])
140
- print("Removing %s duplicate files from db because ingesting those as new documents" % len(
141
- dup_metadata_files), flush=True)
142
- client_collection = db._client.get_collection(name=db._collection.name)
143
- for dup_file in dup_metadata_files:
144
- dup_file_meta = dict(source=dup_file)
145
- try:
146
- client_collection.delete(where=dup_file_meta)
147
- except KeyError:
148
- pass
149
- num_new_sources = len(sources)
150
- if num_new_sources == 0:
151
- return db, num_new_sources, []
152
- db.add_documents(documents=sources)
153
- db.persist()
154
- else:
155
- raise RuntimeError("No such db_type=%s" % db_type)
156
-
157
- new_sources_metadata = [x.metadata for x in sources]
158
-
159
- return db, num_new_sources, new_sources_metadata
160
-
161
-
162
- def create_or_update_db(db_type, persist_directory, collection_name,
163
- sources, use_openai_embedding, add_if_exists, verbose, hf_embedding_model):
164
- if db_type == 'weaviate':
165
- import weaviate
166
- from weaviate.embedded import EmbeddedOptions
167
-
168
- # TODO: add support for connecting via docker compose
169
- client = weaviate.Client(
170
- embedded_options=EmbeddedOptions()
171
- )
172
- index_name = collection_name.replace(' ', '_').capitalize()
173
- if client.schema.exists(index_name) and not add_if_exists:
174
- client.schema.delete_class(index_name)
175
- if verbose:
176
- print("Removing %s" % index_name, flush=True)
177
- elif db_type == 'chroma':
178
- if not os.path.isdir(persist_directory) or not add_if_exists:
179
- if os.path.isdir(persist_directory):
180
- if verbose:
181
- print("Removing %s" % persist_directory, flush=True)
182
- remove(persist_directory)
183
- if verbose:
184
- print("Generating db", flush=True)
185
-
186
- if not add_if_exists:
187
- if verbose:
188
- print("Generating db", flush=True)
189
- else:
190
- if verbose:
191
- print("Loading and updating db", flush=True)
192
-
193
- db = get_db(sources,
194
- use_openai_embedding=use_openai_embedding,
195
- db_type=db_type,
196
- persist_directory=persist_directory,
197
- langchain_mode=collection_name,
198
- hf_embedding_model=hf_embedding_model)
199
-
200
- return db
201
-
202
-
203
- def get_embedding(use_openai_embedding, hf_embedding_model="sentence-transformers/all-MiniLM-L6-v2"):
204
- # Get embedding model
205
- if use_openai_embedding:
206
- assert os.getenv("OPENAI_API_KEY") is not None, "Set ENV OPENAI_API_KEY"
207
- from langchain.embeddings import OpenAIEmbeddings
208
- embedding = OpenAIEmbeddings()
209
- else:
210
- # to ensure can fork without deadlock
211
- from langchain.embeddings import HuggingFaceEmbeddings
212
-
213
- device, torch_dtype, context_class = get_device_dtype()
214
- model_kwargs = dict(device=device)
215
- embedding = HuggingFaceEmbeddings(model_name=hf_embedding_model, model_kwargs=model_kwargs)
216
- return embedding
217
-
218
-
219
- def get_answer_from_sources(chain, sources, question):
220
- return chain(
221
- {
222
- "input_documents": sources,
223
- "question": question,
224
- },
225
- return_only_outputs=True,
226
- )["output_text"]
227
-
228
-
229
- def get_llm(use_openai_model=False, model_name=None, model=None,
230
- tokenizer=None, stream_output=False,
231
- max_new_tokens=256,
232
- temperature=0.1,
233
- repetition_penalty=1.0,
234
- top_k=40,
235
- top_p=0.7,
236
- prompt_type=None,
237
- prompter=None,
238
- verbose=False,
239
- ):
240
- if use_openai_model:
241
- from langchain.llms import OpenAI
242
- llm = OpenAI(temperature=0)
243
- model_name = 'openai'
244
- streamer = None
245
- prompt_type = 'plain'
246
- elif model_name in non_hf_types:
247
- from gpt4all_llm import get_llm_gpt4all
248
- llm = get_llm_gpt4all(model_name, model=model, max_new_tokens=max_new_tokens,
249
- temperature=temperature,
250
- repetition_penalty=repetition_penalty,
251
- top_k=top_k,
252
- top_p=top_p,
253
- verbose=verbose,
254
- )
255
- streamer = None
256
- prompt_type = 'plain'
257
- else:
258
- from transformers import AutoTokenizer, AutoModelForCausalLM
259
-
260
- if model is None:
261
- # only used if didn't pass model in
262
- assert model_name is None
263
- assert tokenizer is None
264
- prompt_type = 'human_bot'
265
- model_name = 'h2oai/h2ogpt-oasst1-512-12b'
266
- # model_name = 'h2oai/h2ogpt-oig-oasst1-512-6_9b'
267
- # model_name = 'h2oai/h2ogpt-oasst1-512-20b'
268
- tokenizer = AutoTokenizer.from_pretrained(model_name)
269
- device, torch_dtype, context_class = get_device_dtype()
270
-
271
- with context_class(device):
272
- load_8bit = True
273
- # FIXME: for now not to spread across hetero GPUs
274
- # device_map={"": 0} if load_8bit and device == 'cuda' else "auto"
275
- device_map = {"": 0} if device == 'cuda' else "auto"
276
- model = AutoModelForCausalLM.from_pretrained(model_name,
277
- device_map=device_map,
278
- torch_dtype=torch_dtype,
279
- load_in_8bit=load_8bit)
280
-
281
- max_max_tokens = tokenizer.model_max_length
282
- gen_kwargs = dict(max_new_tokens=max_new_tokens,
283
- return_full_text=True,
284
- early_stopping=False,
285
- handle_long_generation='hole')
286
-
287
- if stream_output:
288
- skip_prompt = False
289
- from generate import H2OTextIteratorStreamer
290
- decoder_kwargs = {}
291
- streamer = H2OTextIteratorStreamer(tokenizer, skip_prompt=skip_prompt, block=False, **decoder_kwargs)
292
- gen_kwargs.update(dict(streamer=streamer))
293
- else:
294
- streamer = None
295
-
296
- from h2oai_pipeline import H2OTextGenerationPipeline
297
- pipe = H2OTextGenerationPipeline(model=model, use_prompter=True,
298
- prompter=prompter,
299
- prompt_type=prompt_type,
300
- sanitize_bot_response=True,
301
- chat=False, stream_output=stream_output,
302
- tokenizer=tokenizer,
303
- max_input_tokens=max_max_tokens - max_new_tokens,
304
- **gen_kwargs)
305
- # pipe.task = "text-generation"
306
- # below makes it listen only to our prompt removal,
307
- # not built in prompt removal that is less general and not specific for our model
308
- pipe.task = "text2text-generation"
309
-
310
- from langchain.llms import HuggingFacePipeline
311
- llm = HuggingFacePipeline(pipeline=pipe)
312
- return llm, model_name, streamer, prompt_type
313
-
314
-
315
- def get_device_dtype():
316
- # torch.device("cuda") leads to cuda:x cuda:y mismatches for multi-GPU consistently
317
- import torch
318
- n_gpus = torch.cuda.device_count() if torch.cuda.is_available else 0
319
- device = 'cpu' if n_gpus == 0 else 'cuda'
320
- # from utils import NullContext
321
- # context_class = NullContext if n_gpus > 1 or n_gpus == 0 else context_class
322
- context_class = torch.device
323
- torch_dtype = torch.float16 if device == 'cuda' else torch.float32
324
- return device, torch_dtype, context_class
325
-
326
-
327
- def get_wiki_data(title, first_paragraph_only, text_limit=None, take_head=True):
328
- """
329
- Get wikipedia data from online
330
- :param title:
331
- :param first_paragraph_only:
332
- :param text_limit:
333
- :param take_head:
334
- :return:
335
- """
336
- filename = 'wiki_%s_%s_%s_%s.data' % (first_paragraph_only, title, text_limit, take_head)
337
- url = f"https://en.wikipedia.org/w/api.php?format=json&action=query&prop=extracts&explaintext=1&titles={title}"
338
- if first_paragraph_only:
339
- url += "&exintro=1"
340
- import json
341
- if not os.path.isfile(filename):
342
- data = requests.get(url).json()
343
- json.dump(data, open(filename, 'wt'))
344
- else:
345
- data = json.load(open(filename, "rt"))
346
- page_content = list(data["query"]["pages"].values())[0]["extract"]
347
- if take_head is not None and text_limit is not None:
348
- page_content = page_content[:text_limit] if take_head else page_content[:-text_limit]
349
- title_url = str(title).replace(' ', '_')
350
- return Document(
351
- page_content=page_content,
352
- metadata={"source": f"https://en.wikipedia.org/wiki/{title_url}"},
353
- )
354
-
355
-
356
- def get_wiki_sources(first_para=True, text_limit=None):
357
- """
358
- Get specific named sources from wikipedia
359
- :param first_para:
360
- :param text_limit:
361
- :return:
362
- """
363
- default_wiki_sources = ['Unix', 'Microsoft_Windows', 'Linux']
364
- wiki_sources = list(os.getenv('WIKI_SOURCES', default_wiki_sources))
365
- return [get_wiki_data(x, first_para, text_limit=text_limit) for x in wiki_sources]
366
-
367
-
368
- def get_github_docs(repo_owner, repo_name):
369
- """
370
- Access github from specific repo
371
- :param repo_owner:
372
- :param repo_name:
373
- :return:
374
- """
375
- with tempfile.TemporaryDirectory() as d:
376
- subprocess.check_call(
377
- f"git clone --depth 1 https://github.com/{repo_owner}/{repo_name}.git .",
378
- cwd=d,
379
- shell=True,
380
- )
381
- git_sha = (
382
- subprocess.check_output("git rev-parse HEAD", shell=True, cwd=d)
383
- .decode("utf-8")
384
- .strip()
385
- )
386
- repo_path = pathlib.Path(d)
387
- markdown_files = list(repo_path.glob("*/*.md")) + list(
388
- repo_path.glob("*/*.mdx")
389
- )
390
- for markdown_file in markdown_files:
391
- with open(markdown_file, "r") as f:
392
- relative_path = markdown_file.relative_to(repo_path)
393
- github_url = f"https://github.com/{repo_owner}/{repo_name}/blob/{git_sha}/{relative_path}"
394
- yield Document(page_content=f.read(), metadata={"source": github_url})
395
-
396
-
397
- def get_dai_pickle(dest="."):
398
- from huggingface_hub import hf_hub_download
399
- # True for case when locally already logged in with correct token, so don't have to set key
400
- token = os.getenv('HUGGINGFACE_API_TOKEN', True)
401
- path_to_zip_file = hf_hub_download('h2oai/dai_docs', 'dai_docs.pickle', token=token, repo_type='dataset')
402
- shutil.copy(path_to_zip_file, dest)
403
-
404
-
405
- def get_dai_docs(from_hf=False, get_pickle=True):
406
- """
407
- Consume DAI documentation, or consume from public pickle
408
- :param from_hf: get DAI docs from HF, then generate pickle for later use by LangChain
409
- :param get_pickle: Avoid raw DAI docs, just get pickle directly from HF
410
- :return:
411
- """
412
- import pickle
413
-
414
- if get_pickle:
415
- get_dai_pickle()
416
-
417
- dai_store = 'dai_docs.pickle'
418
- dst = "working_dir_docs"
419
- if not os.path.isfile(dai_store):
420
- from create_data import setup_dai_docs
421
- dst = setup_dai_docs(dst=dst, from_hf=from_hf)
422
-
423
- import glob
424
- files = list(glob.glob(os.path.join(dst, '*rst'), recursive=True))
425
-
426
- basedir = os.path.abspath(os.getcwd())
427
- from create_data import rst_to_outputs
428
- new_outputs = rst_to_outputs(files)
429
- os.chdir(basedir)
430
-
431
- pickle.dump(new_outputs, open(dai_store, 'wb'))
432
- else:
433
- new_outputs = pickle.load(open(dai_store, 'rb'))
434
-
435
- sources = []
436
- for line, file in new_outputs:
437
- # gradio requires any linked file to be with app.py
438
- sym_src = os.path.abspath(os.path.join(dst, file))
439
- sym_dst = os.path.abspath(os.path.join(os.getcwd(), file))
440
- if os.path.lexists(sym_dst):
441
- os.remove(sym_dst)
442
- os.symlink(sym_src, sym_dst)
443
- itm = Document(page_content=line, metadata={"source": file})
444
- # NOTE: yield has issues when going into db, loses metadata
445
- # yield itm
446
- sources.append(itm)
447
- return sources
448
-
449
-
450
- import distutils.spawn
451
-
452
- have_tesseract = distutils.spawn.find_executable("tesseract")
453
- have_libreoffice = distutils.spawn.find_executable("libreoffice")
454
-
455
- import pkg_resources
456
-
457
- try:
458
- assert pkg_resources.get_distribution('arxiv') is not None
459
- assert pkg_resources.get_distribution('pymupdf') is not None
460
- have_arxiv = True
461
- except (pkg_resources.DistributionNotFound, AssertionError):
462
- have_arxiv = False
463
-
464
- try:
465
- assert pkg_resources.get_distribution('pymupdf') is not None
466
- have_pymupdf = True
467
- except (pkg_resources.DistributionNotFound, AssertionError):
468
- have_pymupdf = False
469
-
470
- image_types = ["png", "jpg", "jpeg"]
471
- non_image_types = ["pdf", "txt", "csv", "toml", "py", "rst", "rtf",
472
- "md", "html",
473
- "enex", "eml", "epub", "odt", "pptx", "ppt",
474
- "zip", "urls",
475
- ]
476
- # "msg", GPL3
477
-
478
- if have_libreoffice:
479
- non_image_types.extend(["docx", "doc"])
480
-
481
- file_types = non_image_types + image_types
482
-
483
-
484
- def add_meta(docs1, file):
485
- file_extension = pathlib.Path(file).suffix
486
- hashid = hash_file(file)
487
- if not isinstance(docs1, list):
488
- docs1 = [docs1]
489
- [x.metadata.update(dict(input_type=file_extension, date=str(datetime.now()), hashid=hashid)) for x in docs1]
490
-
491
-
492
- def file_to_doc(file, base_path=None, verbose=False, fail_any_exception=False, chunk=True, chunk_size=512,
493
- is_url=False, is_txt=False,
494
- enable_captions=True,
495
- captions_model=None,
496
- enable_ocr=False, caption_loader=None,
497
- headsize=50):
498
- if file is None:
499
- if fail_any_exception:
500
- raise RuntimeError("Unexpected None file")
501
- else:
502
- return []
503
- doc1 = [] # in case no support, or disabled support
504
- if base_path is None and not is_txt and not is_url:
505
- # then assume want to persist but don't care which path used
506
- # can't be in base_path
507
- dir_name = os.path.dirname(file)
508
- base_name = os.path.basename(file)
509
- # if from gradio, will have its own temp uuid too, but that's ok
510
- base_name = sanitize_filename(base_name) + "_" + str(uuid.uuid4())[:10]
511
- base_path = os.path.join(dir_name, base_name)
512
- if is_url:
513
- if file.lower().startswith('arxiv:'):
514
- query = file.lower().split('arxiv:')
515
- if len(query) == 2 and have_arxiv:
516
- query = query[1]
517
- docs1 = ArxivLoader(query=query, load_max_docs=20, load_all_available_meta=True).load()
518
- # ensure string, sometimes None
519
- [[x.metadata.update({k: str(v)}) for k, v in x.metadata.items()] for x in docs1]
520
- query_url = f"https://arxiv.org/abs/{query}"
521
- [x.metadata.update(
522
- dict(source=x.metadata.get('entry_id', query_url), query=query_url,
523
- input_type='arxiv', head=x.metadata.get('Title', ''), date=str(datetime.now()))) for x in
524
- docs1]
525
- else:
526
- docs1 = []
527
- else:
528
- docs1 = UnstructuredURLLoader(urls=[file]).load()
529
- [x.metadata.update(dict(input_type='url', date=str(datetime.now()))) for x in docs1]
530
- doc1 = chunk_sources(docs1, chunk_size=chunk_size)
531
- elif is_txt:
532
- base_path = "user_paste"
533
- source_file = os.path.join(base_path, "_%s" % str(uuid.uuid4())[:10])
534
- makedirs(os.path.dirname(source_file), exist_ok=True)
535
- with open(source_file, "wt") as f:
536
- f.write(file)
537
- metadata = dict(source=source_file, date=str(datetime.now()), input_type='pasted txt')
538
- doc1 = Document(page_content=file, metadata=metadata)
539
- elif file.lower().endswith('.html') or file.lower().endswith('.mhtml'):
540
- docs1 = UnstructuredHTMLLoader(file_path=file).load()
541
- add_meta(docs1, file)
542
- doc1 = chunk_sources(docs1, chunk_size=chunk_size)
543
- elif (file.lower().endswith('.docx') or file.lower().endswith('.doc')) and have_libreoffice:
544
- docs1 = UnstructuredWordDocumentLoader(file_path=file).load()
545
- add_meta(docs1, file)
546
- doc1 = chunk_sources(docs1, chunk_size=chunk_size)
547
- elif file.lower().endswith('.odt'):
548
- docs1 = UnstructuredODTLoader(file_path=file).load()
549
- add_meta(docs1, file)
550
- doc1 = chunk_sources(docs1, chunk_size=chunk_size)
551
- elif file.lower().endswith('pptx') or file.lower().endswith('ppt'):
552
- docs1 = UnstructuredPowerPointLoader(file_path=file).load()
553
- add_meta(docs1, file)
554
- doc1 = chunk_sources(docs1, chunk_size=chunk_size)
555
- elif file.lower().endswith('.txt'):
556
- # use UnstructuredFileLoader ?
557
- docs1 = TextLoader(file, encoding="utf8", autodetect_encoding=True).load()
558
- # makes just one, but big one
559
- doc1 = chunk_sources(docs1, chunk_size=chunk_size)
560
- add_meta(doc1, file)
561
- elif file.lower().endswith('.rtf'):
562
- docs1 = UnstructuredRTFLoader(file).load()
563
- add_meta(docs1, file)
564
- doc1 = chunk_sources(docs1, chunk_size=chunk_size)
565
- elif file.lower().endswith('.md'):
566
- docs1 = UnstructuredMarkdownLoader(file).load()
567
- add_meta(docs1, file)
568
- doc1 = chunk_sources(docs1, chunk_size=chunk_size)
569
- elif file.lower().endswith('.enex'):
570
- docs1 = EverNoteLoader(file).load()
571
- add_meta(docs1, file)
572
- doc1 = chunk_sources(docs1, chunk_size=chunk_size)
573
- elif file.lower().endswith('.epub'):
574
- docs1 = UnstructuredEPubLoader(file).load()
575
- add_meta(docs1, file)
576
- doc1 = chunk_sources(docs1, chunk_size=chunk_size)
577
- elif file.lower().endswith('.jpeg') or file.lower().endswith('.jpg') or file.lower().endswith('.png'):
578
- docs1 = []
579
- if have_tesseract and enable_ocr:
580
- # OCR, somewhat works, but not great
581
- docs1.extend(UnstructuredImageLoader(file).load())
582
- add_meta(docs1, file)
583
- if enable_captions:
584
- # BLIP
585
- if caption_loader is not None and not isinstance(caption_loader, (str, bool)):
586
- # assumes didn't fork into this process with joblib, else can deadlock
587
- caption_loader.set_image_paths([file])
588
- docs1c = caption_loader.load()
589
- add_meta(docs1c, file)
590
- [x.metadata.update(dict(head=x.page_content[:headsize].strip())) for x in docs1c]
591
- docs1.extend(docs1c)
592
- else:
593
- from image_captions import H2OImageCaptionLoader
594
- caption_loader = H2OImageCaptionLoader(caption_gpu=caption_loader == 'gpu',
595
- blip_model=captions_model,
596
- blip_processor=captions_model)
597
- caption_loader.set_image_paths([file])
598
- docs1c = caption_loader.load()
599
- add_meta(docs1c, file)
600
- [x.metadata.update(dict(head=x.page_content[:headsize].strip())) for x in docs1c]
601
- docs1.extend(docs1c)
602
- for doci in docs1:
603
- doci.metadata['source'] = doci.metadata['image_path']
604
- doci.metadata['hash'] = hash_file(doci.metadata['source'])
605
- if docs1:
606
- doc1 = chunk_sources(docs1, chunk_size=chunk_size)
607
- elif file.lower().endswith('.msg'):
608
- raise RuntimeError("Not supported, GPL3 license")
609
- # docs1 = OutlookMessageLoader(file).load()
610
- # docs1[0].metadata['source'] = file
611
- elif file.lower().endswith('.eml'):
612
- try:
613
- docs1 = UnstructuredEmailLoader(file).load()
614
- add_meta(docs1, file)
615
- doc1 = chunk_sources(docs1, chunk_size=chunk_size)
616
- except ValueError as e:
617
- if 'text/html content not found in email' in str(e):
618
- # e.g. the text/plain key exists in the email, but text/html does not
619
- # doc1 = TextLoader(file, encoding="utf8").load()
620
- docs1 = UnstructuredEmailLoader(file, content_source="text/plain").load()
621
- add_meta(docs1, file)
622
- doc1 = chunk_sources(docs1, chunk_size=chunk_size)
623
- else:
624
- raise
625
- # elif file.lower().endswith('.gcsdir'):
626
- # doc1 = GCSDirectoryLoader(project_name, bucket, prefix).load()
627
- # elif file.lower().endswith('.gcsfile'):
628
- # doc1 = GCSFileLoader(project_name, bucket, blob).load()
629
- elif file.lower().endswith('.rst'):
630
- with open(file, "r") as f:
631
- doc1 = Document(page_content=f.read(), metadata={"source": file})
632
- add_meta(doc1, file)
633
- elif file.lower().endswith('.pdf'):
634
- env_gpt4all_file = ".env_gpt4all"
635
- from dotenv import dotenv_values
636
- env_kwargs = dotenv_values(env_gpt4all_file)
637
- pdf_class_name = env_kwargs.get('PDF_CLASS_NAME', 'PyMuPDFParser')
638
- if have_pymupdf and pdf_class_name == 'PyMuPDFParser':
639
- # GPL, only use if installed
640
- from langchain.document_loaders import PyMuPDFLoader
641
- doc1 = PyMuPDFLoader(file).load_and_split()
642
- else:
643
- # open-source fallback
644
- doc1 = PyPDFLoader(file).load_and_split()
645
- # Some PDFs return nothing or junk from PDFMinerLoader
646
- add_meta(doc1, file)
647
- elif file.lower().endswith('.csv'):
648
- doc1 = CSVLoader(file).load()
649
- add_meta(doc1, file)
650
- elif file.lower().endswith('.py'):
651
- doc1 = PythonLoader(file).load()
652
- add_meta(doc1, file)
653
- elif file.lower().endswith('.toml'):
654
- doc1 = TomlLoader(file).load()
655
- add_meta(doc1, file)
656
- elif file.lower().endswith('.urls'):
657
- with open(file, "r") as f:
658
- docs1 = UnstructuredURLLoader(urls=f.readlines()).load()
659
- add_meta(docs1, file)
660
- doc1 = chunk_sources(docs1, chunk_size=chunk_size)
661
- elif file.lower().endswith('.zip'):
662
- with zipfile.ZipFile(file, 'r') as zip_ref:
663
- # don't put into temporary path, since want to keep references to docs inside zip
664
- # so just extract into base_path, next to where the zip lives
665
- zip_ref.extractall(base_path)
666
- # recurse
667
- doc1 = path_to_docs(base_path, verbose=verbose, fail_any_exception=fail_any_exception)
668
- else:
669
- raise RuntimeError("No file handler for %s" % os.path.basename(file))
670
-
671
- # allow doc1 to be list or not. If not list, did not chunk yet, so chunk now
672
- # if list of length one, don't trust and chunk it
673
- if not isinstance(doc1, list):
674
- if chunk:
675
- docs = chunk_sources([doc1], chunk_size=chunk_size)
676
- else:
677
- docs = [doc1]
678
- elif isinstance(doc1, list) and len(doc1) == 1:
679
- if chunk:
680
- docs = chunk_sources(doc1, chunk_size=chunk_size)
681
- else:
682
- docs = doc1
683
- else:
684
- docs = doc1
685
-
686
- assert isinstance(docs, list)
687
- return docs
688
-
689
-
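# A minimal sketch of file_to_doc() above, which dispatches on the file extension and
# returns a list of (optionally chunked) LangChain Documents.  "README.md" is just a
# placeholder for any local file of a supported type; gpt_langchain is assumed importable.
from gpt_langchain import file_to_doc

docs = file_to_doc("README.md", chunk=True, chunk_size=512, enable_captions=False)
for d in docs[:3]:
    print(d.metadata.get("source"), d.metadata.get("input_type"), len(d.page_content))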
690
- def path_to_doc1(file, verbose=False, fail_any_exception=False, return_file=True, chunk=True, chunk_size=512,
691
- is_url=False, is_txt=False,
692
- enable_captions=True,
693
- captions_model=None,
694
- enable_ocr=False, caption_loader=None):
695
- if verbose:
696
- if is_url:
697
- print("Ingesting URL: %s" % file, flush=True)
698
- elif is_txt:
699
- print("Ingesting Text: %s" % file, flush=True)
700
- else:
701
- print("Ingesting file: %s" % file, flush=True)
702
- res = None
703
- try:
704
- # don't pass base_path=path, would infinitely recurse
705
- res = file_to_doc(file, base_path=None, verbose=verbose, fail_any_exception=fail_any_exception,
706
- chunk=chunk, chunk_size=chunk_size,
707
- is_url=is_url, is_txt=is_txt,
708
- enable_captions=enable_captions,
709
- captions_model=captions_model,
710
- enable_ocr=enable_ocr,
711
- caption_loader=caption_loader)
712
- except BaseException as e:
713
- print("Failed to ingest %s due to %s" % (file, traceback.format_exc()))
714
- if fail_any_exception:
715
- raise
716
- else:
717
- exception_doc = Document(
718
- page_content='',
719
- metadata={"source": file, "exception": str(e), "traceback": traceback.format_exc()})
720
- res = [exception_doc]
721
- if return_file:
722
- base_tmp = "temp_path_to_doc1"
723
- if not os.path.isdir(base_tmp):
724
- os.makedirs(base_tmp, exist_ok=True)
725
- filename = os.path.join(base_tmp, str(uuid.uuid4()) + ".tmp.pickle")
726
- with open(filename, 'wb') as f:
727
- pickle.dump(res, f)
728
- return filename
729
- return res
730
-
731
-
732
- def path_to_docs(path_or_paths, verbose=False, fail_any_exception=False, n_jobs=-1,
733
- chunk=True, chunk_size=512,
734
- url=None, text=None,
735
- enable_captions=True,
736
- captions_model=None,
737
- caption_loader=None,
738
- enable_ocr=False,
739
- existing_files=[],
740
- existing_hash_ids={},
741
- ):
742
- globs_image_types = []
743
- globs_non_image_types = []
744
- if not path_or_paths and not url and not text:
745
- return []
746
- elif url:
747
- globs_non_image_types = [url]
748
- elif text:
749
- globs_non_image_types = [text]
750
- elif isinstance(path_or_paths, str):
751
- # single path, only consume allowed files
752
- path = path_or_paths
753
- # Below globs should match patterns in file_to_doc()
754
- [globs_image_types.extend(glob.glob(os.path.join(path, "./**/*.%s" % ftype), recursive=True))
755
- for ftype in image_types]
756
- [globs_non_image_types.extend(glob.glob(os.path.join(path, "./**/*.%s" % ftype), recursive=True))
757
- for ftype in non_image_types]
758
- else:
759
- # list/tuple of files (consume what can, and exception those that selected but cannot consume so user knows)
760
- assert isinstance(path_or_paths, (list, tuple)), "Wrong type for path_or_paths: %s" % type(path_or_paths)
761
- # reform out of allowed types
762
- globs_image_types.extend(flatten_list([[x for x in path_or_paths if x.endswith(y)] for y in image_types]))
763
- # could do below:
764
- # globs_non_image_types = flatten_list([[x for x in path_or_paths if x.endswith(y)] for y in non_image_types])
765
- # But instead, allow fail so can collect unsupported too
766
- set_globs_image_types = set(globs_image_types)
767
- globs_non_image_types.extend([x for x in path_or_paths if x not in set_globs_image_types])
768
-
769
- # filter out any files to skip (e.g. if already processed them)
770
- # this is easy, but too aggressive in case a file changed, so parent probably passed existing_files=[]
771
- assert not existing_files, "DEV: assume not using this approach"
772
- if existing_files:
773
- set_skip_files = set(existing_files)
774
- globs_image_types = [x for x in globs_image_types if x not in set_skip_files]
775
- globs_non_image_types = [x for x in globs_non_image_types if x not in set_skip_files]
776
- if existing_hash_ids:
777
- # assume consistent with add_meta() use of hash_file(file)
778
- # also assume consistent with get_existing_hash_ids for dict creation
779
- # assume hashable values
780
- existing_hash_ids_set = set(existing_hash_ids.items())
781
- hash_ids_all_image = set({x: hash_file(x) for x in globs_image_types}.items())
782
- hash_ids_all_non_image = set({x: hash_file(x) for x in globs_non_image_types}.items())
783
- # don't use a symmetric difference; if a file is gone, just ignore it rather than removing anything
784
- # just consider existing files (key) having new hash or not (value)
785
- new_files_image = set(dict(hash_ids_all_image - existing_hash_ids_set).keys())
786
- new_files_non_image = set(dict(hash_ids_all_non_image - existing_hash_ids_set).keys())
787
- globs_image_types = [x for x in globs_image_types if x in new_files_image]
788
- globs_non_image_types = [x for x in globs_non_image_types if x in new_files_non_image]
789
-
790
- # could use generator, but messes up metadata handling in recursive case
791
- if caption_loader and not isinstance(caption_loader, (bool, str)) and \
792
- caption_loader.device != 'cpu' or \
793
- get_device() == 'cuda':
794
- # to avoid deadlocks, presume was preloaded and so can't fork due to cuda context
795
- n_jobs_image = 1
796
- else:
797
- n_jobs_image = n_jobs
798
-
799
- return_file = True # local choice
800
- is_url = url is not None
801
- is_txt = text is not None
802
- kwargs = dict(verbose=verbose, fail_any_exception=fail_any_exception,
803
- return_file=return_file,
804
- chunk=chunk, chunk_size=chunk_size,
805
- is_url=is_url,
806
- is_txt=is_txt,
807
- enable_captions=enable_captions,
808
- captions_model=captions_model,
809
- caption_loader=caption_loader,
810
- enable_ocr=enable_ocr,
811
- )
812
-
813
- if n_jobs != 1 and len(globs_non_image_types) > 1:
814
- # avoid nesting, e.g. upload 1 zip and then inside many files
815
- # harder to handle if upload many zips with many files, inner parallel one will be disabled by joblib
816
- documents = ProgressParallel(n_jobs=n_jobs, verbose=10 if verbose else 0, backend='multiprocessing')(
817
- delayed(path_to_doc1)(file, **kwargs) for file in globs_non_image_types
818
- )
819
- else:
820
- documents = [path_to_doc1(file, **kwargs) for file in tqdm(globs_non_image_types)]
821
-
822
- # do images separately since can't fork after cuda in parent, so can't be parallel
823
- if n_jobs_image != 1 and len(globs_image_types) > 1:
824
- # avoid nesting, e.g. upload 1 zip and then inside many files
825
- # harder to handle if upload many zips with many files, inner parallel one will be disabled by joblib
826
- image_documents = ProgressParallel(n_jobs=n_jobs_image, verbose=10 if verbose else 0, backend='multiprocessing')(
827
- delayed(path_to_doc1)(file, **kwargs) for file in globs_image_types
828
- )
829
- else:
830
- image_documents = [path_to_doc1(file, **kwargs) for file in tqdm(globs_image_types)]
831
-
832
- # add image docs in
833
- documents += image_documents
834
-
835
- if return_file:
836
- # then documents really are files
837
- files = documents.copy()
838
- documents = []
839
- for fil in files:
840
- with open(fil, 'rb') as f:
841
- documents.extend(pickle.load(f))
842
- # remove temp pickle
843
- os.remove(fil)
844
- else:
845
- documents = reduce(concat, documents)
846
- return documents
847
-
848
-
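# A minimal sketch of path_to_docs() above: glob a directory for the supported file types,
# ingest the files in parallel (images serially when CUDA is in play), and get back one
# flat list of Documents.  "user_path" is a placeholder directory; gpt_langchain is
# assumed importable.
from gpt_langchain import path_to_docs

documents = path_to_docs("user_path", n_jobs=4, chunk=True, chunk_size=512,
                         enable_captions=False, enable_ocr=False)
print("Ingested %d document chunks" % len(documents))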
849
- def prep_langchain(persist_directory,
850
- load_db_if_exists,
851
- db_type, use_openai_embedding, langchain_mode, user_path,
852
- hf_embedding_model, n_jobs=-1, kwargs_make_db={}):
853
- """
854
- do prep first time, involving downloads
855
- # FIXME: Add github caching then add here
856
- :return:
857
- """
858
- assert langchain_mode not in ['MyData'], "Should not prep scratch data"
859
-
860
- db_dir_exists = os.path.isdir(persist_directory)
861
-
862
- if db_dir_exists and user_path is None:
863
- print("Prep: persist_directory=%s exists, using" % persist_directory, flush=True)
864
- db = get_existing_db(persist_directory, load_db_if_exists, db_type, use_openai_embedding, langchain_mode,
865
- hf_embedding_model)
866
- else:
867
- if db_dir_exists and user_path is not None:
868
- print("Prep: persist_directory=%s exists, user_path=%s passed, adding any changed or new documents" % (
869
- persist_directory, user_path), flush=True)
870
- elif not db_dir_exists:
871
- print("Prep: persist_directory=%s does not exist, regenerating" % persist_directory, flush=True)
872
- db = None
873
- if langchain_mode in ['All', 'DriverlessAI docs']:
874
- # FIXME: Could also just use dai_docs.pickle directly and upload that
875
- get_dai_docs(from_hf=True)
876
-
877
- if langchain_mode in ['All', 'wiki']:
878
- get_wiki_sources(first_para=kwargs_make_db['first_para'], text_limit=kwargs_make_db['text_limit'])
879
-
880
- langchain_kwargs = kwargs_make_db.copy()
881
- langchain_kwargs.update(locals())
882
- db, num_new_sources, new_sources_metadata = make_db(**langchain_kwargs)
883
-
884
- return db
885
-
886
-
887
- import posthog
888
-
889
- posthog.disabled = True
890
-
891
-
892
- class FakeConsumer(object):
893
- def __init__(self, *args, **kwargs):
894
- pass
895
-
896
- def run(self):
897
- pass
898
-
899
- def pause(self):
900
- pass
901
-
902
- def upload(self):
903
- pass
904
-
905
- def next(self):
906
- pass
907
-
908
- def request(self, batch):
909
- pass
910
-
911
-
912
- posthog.Consumer = FakeConsumer
913
-
914
-
915
- def get_existing_db(persist_directory, load_db_if_exists, db_type, use_openai_embedding, langchain_mode,
916
- hf_embedding_model):
917
- if load_db_if_exists and db_type == 'chroma' and os.path.isdir(persist_directory) and os.path.isdir(
918
- os.path.join(persist_directory, 'index')):
919
- print("DO Loading db: %s" % langchain_mode, flush=True)
920
- embedding = get_embedding(use_openai_embedding, hf_embedding_model=hf_embedding_model)
921
- from chromadb.config import Settings
922
- client_settings = Settings(anonymized_telemetry=False,
923
- chroma_db_impl="duckdb+parquet",
924
- persist_directory=persist_directory)
925
- db = Chroma(persist_directory=persist_directory, embedding_function=embedding,
926
- collection_name=langchain_mode.replace(' ', '_'),
927
- client_settings=client_settings)
928
- print("DONE Loading db: %s" % langchain_mode, flush=True)
929
- return db
930
- return None
931
-
932
-
933
- def make_db(**langchain_kwargs):
934
- func_names = list(inspect.signature(_make_db).parameters)
935
- missing_kwargs = [x for x in func_names if x not in langchain_kwargs]
936
- defaults_db = {k: v.default for k, v in dict(inspect.signature(run_qa_db).parameters).items()}
937
- for k in missing_kwargs:
938
- if k in defaults_db:
939
- langchain_kwargs[k] = defaults_db[k]
940
- # final check for missing
941
- missing_kwargs = [x for x in func_names if x not in langchain_kwargs]
942
- assert not missing_kwargs, "Missing kwargs: %s" % missing_kwargs
943
- # only keep actual used
944
- langchain_kwargs = {k: v for k, v in langchain_kwargs.items() if k in func_names}
945
- return _make_db(**langchain_kwargs)
946
-
947
-
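# make_db() above fills any missing kwargs from run_qa_db()'s defaults and then keeps only
# what _make_db() accepts.  The same signature-filtering idiom, in a self-contained sketch:
import inspect

def _target(a, b=2, c=3):
    return a + b + c

def call_with_filtered_kwargs(func, **kwargs):
    params = inspect.signature(func).parameters
    # fill anything missing from the function's own defaults
    for name, p in params.items():
        if name not in kwargs and p.default is not inspect.Parameter.empty:
            kwargs[name] = p.default
    missing = [name for name in params if name not in kwargs]
    assert not missing, "Missing kwargs: %s" % missing
    # only keep what the target actually accepts
    kwargs = {k: v for k, v in kwargs.items() if k in params}
    return func(**kwargs)

print(call_with_filtered_kwargs(_target, a=1, b=5, unrelated="ignored"))  # -> 9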
948
- def _make_db(use_openai_embedding=False,
949
- hf_embedding_model="sentence-transformers/all-MiniLM-L6-v2",
950
- first_para=False, text_limit=None, chunk=False, chunk_size=1024,
951
- langchain_mode=None,
952
- user_path=None,
953
- db_type='faiss',
954
- load_db_if_exists=True,
955
- db=None,
956
- n_jobs=-1,
957
- verbose=False):
958
- persist_directory = 'db_dir_%s' % langchain_mode # single place, no special names for each case
959
- if not db and load_db_if_exists and db_type == 'chroma' and os.path.isdir(persist_directory) and os.path.isdir(
960
- os.path.join(persist_directory, 'index')):
961
- assert langchain_mode not in ['MyData'], "Should not load MyData db this way"
962
- print("Loading existing db", flush=True)
963
- embedding = get_embedding(use_openai_embedding, hf_embedding_model=hf_embedding_model)
964
- from chromadb.config import Settings
965
- client_settings = Settings(anonymized_telemetry=False,
966
- chroma_db_impl="duckdb+parquet",
967
- persist_directory=persist_directory)
968
- db = Chroma(persist_directory=persist_directory, embedding_function=embedding,
969
- collection_name=langchain_mode.replace(' ', '_'),
970
- client_settings=client_settings)
971
- sources = []
972
- if not db and langchain_mode not in ['MyData'] or \
973
- user_path is not None and \
974
- langchain_mode in ['UserData']:
975
- # Should not make MyData db this way; MyData is only populated via uploads from the UI
976
- assert langchain_mode not in ['MyData'], "Should not make MyData db this way"
977
- if verbose:
978
- if langchain_mode in ['UserData']:
979
- if user_path is not None:
980
- print("Checking for changed or new sources in %s, and generating sources from them" % user_path,
981
- flush=True)
982
- elif db is None:
983
- print("user_path not passed and no db, no sources", flush=True)
984
- else:
985
- print("user_path not passed, using only existing db, no new sources", flush=True)
986
- else:
987
- print("Generating %s sources" % langchain_mode, flush=True)
988
- if langchain_mode in ['wiki_full', 'All', "'All'"]:
989
- from read_wiki_full import get_all_documents
990
- small_test = None
991
- print("Generating new wiki", flush=True)
992
- sources1 = get_all_documents(small_test=small_test, n_jobs=os.cpu_count() // 2)
993
- print("Got new wiki", flush=True)
994
- if chunk:
995
- sources1 = chunk_sources(sources1, chunk_size=chunk_size)
996
- print("Chunked new wiki", flush=True)
997
- sources.extend(sources1)
998
- if langchain_mode in ['wiki', 'All', "'All'"]:
999
- sources1 = get_wiki_sources(first_para=first_para, text_limit=text_limit)
1000
- if chunk:
1001
- sources1 = chunk_sources(sources1, chunk_size=chunk_size)
1002
- sources.extend(sources1)
1003
- if langchain_mode in ['github h2oGPT', 'All', "'All'"]:
1004
- # sources = get_github_docs("dagster-io", "dagster")
1005
- sources1 = get_github_docs("h2oai", "h2ogpt")
1006
- # FIXME: always chunk for now
1007
- sources1 = chunk_sources(sources1, chunk_size=chunk_size)
1008
- sources.extend(sources1)
1009
- if langchain_mode in ['DriverlessAI docs', 'All', "'All'"]:
1010
- sources1 = get_dai_docs(from_hf=True)
1011
- if chunk and False: # FIXME: DAI docs are already chunked well, should only chunk more if over limit
1012
- sources1 = chunk_sources(sources1, chunk_size=chunk_size)
1013
- sources.extend(sources1)
1014
- if langchain_mode in ['All', 'UserData']:
1015
- if user_path:
1016
- if db is not None:
1017
- # NOTE: Ignore file names for now, only go by hash ids
1018
- # existing_files = get_existing_files(db)
1019
- existing_files = []
1020
- existing_hash_ids = get_existing_hash_ids(db)
1021
- else:
1022
- # pretend no existing files so won't filter
1023
- existing_files = []
1024
- existing_hash_ids = []
1025
- # chunk internally for speed over multiple docs
1026
- sources1 = path_to_docs(user_path, n_jobs=n_jobs, chunk=chunk, chunk_size=chunk_size,
1027
- existing_files=existing_files, existing_hash_ids=existing_hash_ids)
1028
- new_metadata_sources = set([x.metadata['source'] for x in sources1])
1029
- if new_metadata_sources:
1030
- print("Loaded %s new files as sources to add to UserData" % len(new_metadata_sources), flush=True)
1031
- if verbose:
1032
- print("Files added: %s" % '\n'.join(new_metadata_sources), flush=True)
1033
- sources.extend(sources1)
1034
- print("Loaded %s sources for potentially adding to UserData" % len(sources), flush=True)
1035
- else:
1036
- print("Chose UserData but user_path is empty/None", flush=True)
1037
- if False and langchain_mode in ['urls', 'All', "'All'"]:
1038
- # from langchain.document_loaders import UnstructuredURLLoader
1039
- # loader = UnstructuredURLLoader(urls=urls)
1040
- urls = ["https://www.birdsongsf.com/who-we-are/"]
1041
- from langchain.document_loaders import PlaywrightURLLoader
1042
- loader = PlaywrightURLLoader(urls=urls, remove_selectors=["header", "footer"])
1043
- sources1 = loader.load()
1044
- sources.extend(sources1)
1045
- if not sources:
1046
- if verbose:
1047
- if db is not None:
1048
- print("langchain_mode %s has no new sources, nothing to add to db" % langchain_mode, flush=True)
1049
- else:
1050
- print("langchain_mode %s has no sources, not making new db" % langchain_mode, flush=True)
1051
- return db, 0, []
1052
- if verbose:
1053
- if db is not None:
1054
- print("Adding to db", flush=True)
1055
- else:
1056
- print("Generating db", flush=True)
1057
- if not db:
1058
- if sources:
1059
- db = get_db(sources, use_openai_embedding=use_openai_embedding, db_type=db_type,
1060
- persist_directory=persist_directory, langchain_mode=langchain_mode,
1061
- hf_embedding_model=hf_embedding_model)
1062
- if verbose:
1063
- print("Generated db", flush=True)
1064
- else:
1065
- print("Did not generate db since no sources", flush=True)
1066
- new_sources_metadata = [x.metadata for x in sources]
1067
- elif user_path is not None and langchain_mode in ['UserData']:
1068
- print("Existing db, potentially adding %s sources from user_path=%s" % (len(sources), user_path), flush=True)
1069
- db, num_new_sources, new_sources_metadata = add_to_db(db, sources, db_type=db_type)
1070
- print("Existing db, added %s new sources from user_path=%s" % (num_new_sources, user_path), flush=True)
1071
- else:
1072
- new_sources_metadata = [x.metadata for x in sources]
1073
-
1074
- return db, len(new_sources_metadata), new_sources_metadata
1075
-
1076
-
1077
- def get_existing_files(db):
1078
- collection = db.get()
1079
- metadata_sources = set([x['source'] for x in collection['metadatas']])
1080
- return metadata_sources
1081
-
1082
-
1083
- def get_existing_hash_ids(db):
1084
- collection = db.get()
1085
- # assume consistency, that any prior hashed source was single hashed file at the time among all source chunks
1086
- metadata_hash_ids = {x['source']: x.get('hashid') for x in collection['metadatas']}
1087
- return metadata_hash_ids
1088
-
1089
-
1090
- source_prefix = "Sources [Score | Link]:"
1091
- source_postfix = "End Sources<p>"
1092
-
1093
-
1094
- def run_qa_db(**kwargs):
1095
- func_names = list(inspect.signature(_run_qa_db).parameters)
1096
- # hard-coded defaults
1097
- kwargs['answer_with_sources'] = True
1098
- kwargs['sanitize_bot_response'] = True
1099
- kwargs['show_rank'] = False
1100
- missing_kwargs = [x for x in func_names if x not in kwargs]
1101
- assert not missing_kwargs, "Missing kwargs: %s" % missing_kwargs
1102
- # only keep actual used
1103
- kwargs = {k: v for k, v in kwargs.items() if k in func_names}
1104
- return _run_qa_db(**kwargs)
1105
-
1106
-
1107
- def _run_qa_db(query=None,
1108
- use_openai_model=False, use_openai_embedding=False,
1109
- first_para=False, text_limit=None, k=4, chunk=False, chunk_size=1024,
1110
- user_path=None,
1111
- detect_user_path_changes_every_query=False,
1112
- db_type='faiss',
1113
- model_name=None, model=None, tokenizer=None,
1114
- hf_embedding_model="sentence-transformers/all-MiniLM-L6-v2",
1115
- stream_output=False,
1116
- prompter=None,
1117
- prompt_type=None,
1118
- answer_with_sources=True,
1119
- cut_distanct=1.1,
1120
- sanitize_bot_response=True,
1121
- show_rank=False,
1122
- load_db_if_exists=False,
1123
- db=None,
1124
- max_new_tokens=256,
1125
- temperature=0.1,
1126
- repetition_penalty=1.0,
1127
- top_k=40,
1128
- top_p=0.7,
1129
- langchain_mode=None,
1130
- document_choice=['All'],
1131
- n_jobs=-1,
1132
- verbose=False,
1133
- cli=False):
1134
- """
1135
-
1136
- :param query:
1137
- :param use_openai_model:
1138
- :param use_openai_embedding:
1139
- :param first_para:
1140
- :param text_limit:
1141
- :param k:
1142
- :param chunk:
1143
- :param chunk_size:
1144
- :param user_path: user path to glob recursively from
1145
- :param db_type: 'faiss' for in-memory db or 'chroma' or 'weaviate' for persistent db
1146
- :param model_name: model name, used to switch behaviors
1147
- :param model: pre-initialized model, else will make new one
1148
- :param tokenizer: pre-initialized tokenizer, else will make new one. Required not None if model is not None
1149
- :param answer_with_sources
1150
- :return:
1151
- """
1152
- assert query is not None
1153
- assert prompter is not None or prompt_type is not None or model is None # if model is None, then will generate
1154
- if prompter is not None:
1155
- prompt_type = prompter.prompt_type
1156
- if model is not None:
1157
- assert prompt_type is not None
1158
- llm, model_name, streamer, prompt_type_out = get_llm(use_openai_model=use_openai_model, model_name=model_name,
1159
- model=model, tokenizer=tokenizer,
1160
- stream_output=stream_output,
1161
- max_new_tokens=max_new_tokens,
1162
- temperature=temperature,
1163
- repetition_penalty=repetition_penalty,
1164
- top_k=top_k,
1165
- top_p=top_p,
1166
- prompt_type=prompt_type,
1167
- prompter=prompter,
1168
- verbose=verbose,
1169
- )
1170
-
1171
- if model_name in non_hf_types:
1172
- # FIXME: for now, streams to stdout/stderr currently
1173
- stream_output = False
1174
-
1175
- use_context = False
1176
- scores = []
1177
- chain = None
1178
-
1179
- func_names = list(inspect.signature(get_similarity_chain).parameters)
1180
- sim_kwargs = {k: v for k, v in locals().items() if k in func_names}
1181
- missing_kwargs = [x for x in func_names if x not in sim_kwargs]
1182
- assert not missing_kwargs, "Missing: %s" % missing_kwargs
1183
- docs, chain, scores, use_context = get_similarity_chain(**sim_kwargs)
1184
- if len(document_choice) > 0 and document_choice[0] == 'Only':
1185
- formatted_doc_chunks = '\n\n'.join([get_url(x) + '\n\n' + x.page_content for x in docs])
1186
- yield formatted_doc_chunks, ''
1187
- return
1188
- if chain is None and model_name not in non_hf_types:
1189
- # can only return if HF type
1190
- return
1191
-
1192
- if stream_output:
1193
- answer = None
1194
- assert streamer is not None
1195
- import queue
1196
- bucket = queue.Queue()
1197
- thread = EThread(target=chain, streamer=streamer, bucket=bucket)
1198
- thread.start()
1199
- outputs = ""
1200
- prompt = None # FIXME
1201
- try:
1202
- for new_text in streamer:
1203
- # print("new_text: %s" % new_text, flush=True)
1204
- if bucket.qsize() > 0 or thread.exc:
1205
- thread.join()
1206
- outputs += new_text
1207
- if prompter: # and False: # FIXME: pipeline can already use prompter
1208
- output1 = prompter.get_response(outputs, prompt=prompt,
1209
- sanitize_bot_response=sanitize_bot_response)
1210
- yield output1, ''
1211
- else:
1212
- yield outputs, ''
1213
- except BaseException:
1214
- # if any exception, raise that exception if was from thread, first
1215
- if thread.exc:
1216
- raise thread.exc
1217
- raise
1218
- finally:
1219
- # in case no exception and didn't join with thread yet, then join
1220
- if not thread.exc:
1221
- answer = thread.join()
1222
- # in case raise StopIteration or broke queue loop in streamer, but still have exception
1223
- if thread.exc:
1224
- raise thread.exc
1225
- # FIXME: answer is not string outputs from streamer. How to get actual final output?
1226
- # answer = outputs
1227
- else:
1228
- answer = chain()
1229
-
1230
- if not use_context:
1231
- ret = answer['output_text']
1232
- extra = ''
1233
- yield ret, extra
1234
- elif answer is not None:
1235
- ret, extra = get_sources_answer(query, answer, scores, show_rank, answer_with_sources, verbose=verbose)
1236
- yield ret, extra
1237
- return
1238
-
1239
-
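# _run_qa_db() above is a generator that yields (partial_answer, extra) tuples; when
# streaming, intermediate yields carry the growing answer and the final yield carries the
# formatted sources.  The consumption pattern, stubbed so this sketch runs standalone:
def _fake_run_qa_db(query):
    partial = ""
    for token in ["Driverless ", "AI ", "is ", "an ", "AutoML ", "platform."]:
        partial += token
        yield partial, ""                              # streaming updates, no sources yet
    yield partial, "\nSources [Score | Link]: ..."     # final answer plus sources block

answer, extra = None, ""
for answer, extra in _fake_run_qa_db("What is Driverless AI?"):
    pass                                               # or push each partial answer to a UI
print(answer + extra)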
1240
- def get_similarity_chain(query=None,
1241
- use_openai_model=False, use_openai_embedding=False,
1242
- first_para=False, text_limit=None, k=4, chunk=False, chunk_size=1024,
1243
- user_path=None,
1244
- detect_user_path_changes_every_query=False,
1245
- db_type='faiss',
1246
- model_name=None,
1247
- hf_embedding_model="sentence-transformers/all-MiniLM-L6-v2",
1248
- prompt_type=None,
1249
- cut_distanct=1.1,
1250
- load_db_if_exists=False,
1251
- db=None,
1252
- langchain_mode=None,
1253
- document_choice=['All'],
1254
- n_jobs=-1,
1255
- # beyond run_db_query:
1256
- llm=None,
1257
- verbose=False,
1258
- ):
1259
- # determine whether use of context out of docs is planned
1260
- if not use_openai_model and prompt_type not in ['plain'] or model_name in non_hf_types:
1261
- if langchain_mode in ['Disabled', 'ChatLLM', 'LLM']:
1262
- use_context = False
1263
- else:
1264
- use_context = True
1265
- else:
1266
- use_context = True
1267
-
1268
- # https://github.com/hwchase17/langchain/issues/1946
1269
- # FIXME: Seems there is no way to get the size of the chroma db in order to limit k and avoid the
1270
- # Chroma collection MyData contains fewer than 4 elements.
1271
- # type logger error
1272
- k_db = 1000 if db_type == 'chroma' else k  # k=100 works ok too
1273
-
1274
- # FIXME: For All just go over all dbs instead of a separate db for All
1275
- if not detect_user_path_changes_every_query and db is not None:
1276
- # avoid looking at user_path during similarity search db handling,
1277
- # if already have db and not updating from user_path every query
1278
- # but if db is None, no db yet loaded (e.g. from prep), so allow user_path to be whatever it was
1279
- user_path = None
1280
- db, num_new_sources, new_sources_metadata = make_db(use_openai_embedding=use_openai_embedding,
1281
- hf_embedding_model=hf_embedding_model,
1282
- first_para=first_para, text_limit=text_limit, chunk=chunk,
1283
- chunk_size=chunk_size,
1284
- langchain_mode=langchain_mode,
1285
- user_path=user_path,
1286
- db_type=db_type,
1287
- load_db_if_exists=load_db_if_exists,
1288
- db=db,
1289
- n_jobs=n_jobs,
1290
- verbose=verbose)
1291
-
1292
- if db and use_context:
1293
- if isinstance(document_choice, str):
1294
- # support string as well
1295
- document_choice = [document_choice]
1296
- if not isinstance(db, Chroma) or \
1297
- len(document_choice) == 0 or \
1298
- len(document_choice) <= 1 and document_choice[0] == 'All':
1299
- # treat empty list as All for now, not 'None'
1300
- filter_kwargs = {}
1301
- elif len(document_choice) > 0 and document_choice[0] == 'Only':
1302
- # Only means All docs, but only will return sources, not LLM response
1303
- filter_kwargs = {}
1304
- else:
1305
- if len(document_choice) >= 2:
1306
- or_filter = [{"source": {"$eq": x}} for x in document_choice]
1307
- filter_kwargs = dict(filter={"$or": or_filter})
1308
- elif len(document_choice) > 0:
1309
- one_filter = [{"source": {"$eq": x}} for x in document_choice][0]
1310
- filter_kwargs = dict(filter=one_filter)
1311
- else:
1312
- filter_kwargs = {}
1313
- if len(document_choice) == 1 and document_choice[0] == 'None':
1314
- k_db = 1
1315
- k = 0
1316
- docs_with_score = db.similarity_search_with_score(query, k=k_db, **filter_kwargs)[:k]
1317
- # cut off so no high distance docs/sources considered
1318
- docs = [x[0] for x in docs_with_score if x[1] < cut_distanct]
1319
- scores = [x[1] for x in docs_with_score if x[1] < cut_distanct]
1320
- if len(scores) > 0 and verbose:
1321
- print("Distance: min: %s max: %s mean: %s median: %s" %
1322
- (scores[0], scores[-1], np.mean(scores), np.median(scores)), flush=True)
1323
- else:
1324
- docs = []
1325
- scores = []
1326
-
1327
- if not docs and use_context and model_name not in non_hf_types:
1328
- # if HF type and have no docs, can bail out
1329
- return docs, None, [], False
1330
-
1331
- if len(document_choice) > 0 and document_choice[0] == 'Only':
1332
- # no LLM use
1333
- return docs, None, [], False
1334
-
1335
- common_words_file = "data/NGSL_1.2_stats.csv.zip"
1336
- if os.path.isfile(common_words_file):
1337
- df = pd.read_csv("data/NGSL_1.2_stats.csv.zip")
1338
- import string
1339
- reduced_query = query.translate(str.maketrans(string.punctuation, ' ' * len(string.punctuation))).strip()
1340
- reduced_query_words = reduced_query.split(' ')
1341
- set_common = set(df['Lemma'].values.tolist())
1342
- num_common = len([x for x in reduced_query_words if x.lower() in set_common])
1343
- frac_common = num_common / len(reduced_query_words) if reduced_query_words else 0
1344
- # FIXME: report to user bad query that uses too many common words
1345
- if verbose:
1346
- print("frac_common: %s" % frac_common, flush=True)
1347
-
1348
- if len(docs) == 0:
1349
- # avoid context == in prompt then
1350
- use_context = False
1351
-
1352
- if not use_openai_model and prompt_type not in ['plain'] or model_name in non_hf_types:
1353
- # instruct-like, rather than few-shot prompt_type='plain' as default
1354
- # but then sources confuse the model with how inserted among rest of text, so avoid
1355
- prefix = ""
1356
- if langchain_mode in ['Disabled', 'ChatLLM', 'LLM'] or not use_context:
1357
- template = """%s{context}{question}""" % prefix
1358
- else:
1359
- template = """%s
1360
- ==
1361
- {context}
1362
- ==
1363
- {question}""" % prefix
1364
- prompt = PromptTemplate(
1365
- # input_variables=["summaries", "question"],
1366
- input_variables=["context", "question"],
1367
- template=template,
1368
- )
1369
- chain = load_qa_chain(llm, prompt=prompt)
1370
- else:
1371
- chain = load_qa_with_sources_chain(llm)
1372
-
1373
- if not use_context:
1374
- chain_kwargs = dict(input_documents=[], question=query)
1375
- else:
1376
- chain_kwargs = dict(input_documents=docs, question=query)
1377
-
1378
- target = wrapped_partial(chain, chain_kwargs)
1379
- return docs, target, scores, use_context
1380
-
1381
-
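# get_similarity_chain() above narrows the Chroma similarity search by "source" metadata:
# several selected documents become an $or of $eq filters, a single one becomes a bare $eq
# filter, and 'All'/'Only' (or an empty choice) mean no filter.  The mapping in isolation:
def document_choice_to_filter(document_choice):
    if not document_choice or document_choice[0] in ('All', 'Only'):
        return {}
    if len(document_choice) >= 2:
        return dict(filter={"$or": [{"source": {"$eq": x}} for x in document_choice]})
    return dict(filter={"source": {"$eq": document_choice[0]}})

print(document_choice_to_filter(['All']))            # {}
print(document_choice_to_filter(['a.pdf']))          # single $eq filter
print(document_choice_to_filter(['a.pdf', 'b.md']))  # $or of $eq filters
# the result is splatted into db.similarity_search_with_score(query, k=k_db, **filter_kwargs)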
1382
- def get_sources_answer(query, answer, scores, show_rank, answer_with_sources, verbose=False):
1383
- if verbose:
1384
- print("query: %s" % query, flush=True)
1385
- print("answer: %s" % answer['output_text'], flush=True)
1386
-
1387
- if len(answer['input_documents']) == 0:
1388
- extra = ''
1389
- ret = answer['output_text'] + extra
1390
- return ret, extra
1391
-
1392
- # link
1393
- answer_sources = [(max(0.0, 1.5 - score) / 1.5, get_url(doc)) for score, doc in
1394
- zip(scores, answer['input_documents'])]
1395
- answer_sources_dict = defaultdict(list)
1396
- [answer_sources_dict[url].append(score) for score, url in answer_sources]
1397
- answers_dict = {}
1398
- for url, scores_url in answer_sources_dict.items():
1399
- answers_dict[url] = np.max(scores_url)
1400
- answer_sources = [(score, url) for url, score in answers_dict.items()]
1401
- answer_sources.sort(key=lambda x: x[0], reverse=True)
1402
- if show_rank:
1403
- # answer_sources = ['%d | %s' % (1 + rank, url) for rank, (score, url) in enumerate(answer_sources)]
1404
- # sorted_sources_urls = "Sources [Rank | Link]:<br>" + "<br>".join(answer_sources)
1405
- answer_sources = ['%s' % url for rank, (score, url) in enumerate(answer_sources)]
1406
- sorted_sources_urls = "Ranked Sources:<br>" + "<br>".join(answer_sources)
1407
- else:
1408
- answer_sources = ['<li>%.2g | %s</li>' % (score, url) for score, url in answer_sources]
1409
- sorted_sources_urls = f"{source_prefix}<p><ul>" + "<p>".join(answer_sources)
1410
- sorted_sources_urls += f"</ul></p>{source_postfix}"
1411
-
1412
- if not answer['output_text'].endswith('\n'):
1413
- answer['output_text'] += '\n'
1414
-
1415
- if answer_with_sources:
1416
- extra = '\n' + sorted_sources_urls
1417
- else:
1418
- extra = ''
1419
- ret = answer['output_text'] + extra
1420
- return ret, extra
1421
-
1422
-
1423
- def chunk_sources(sources, chunk_size=1024):
1424
- source_chunks = []
1425
- # Below for known separator
1426
- # splitter = CharacterTextSplitter(separator=" ", chunk_size=chunk_size, chunk_overlap=0)
1427
- splitter = RecursiveCharacterTextSplitter(chunk_size=chunk_size, chunk_overlap=0)
1428
- for source in sources:
1429
- # print(source.metadata['source'], flush=True)
1430
- for chunky in splitter.split_text(source.page_content):
1431
- source_chunks.append(Document(page_content=chunky, metadata=source.metadata))
1432
- return source_chunks
1433
-
1434
-
1435
- def get_db_from_hf(dest=".", db_dir='db_dir_DriverlessAI_docs.zip'):
1436
- from huggingface_hub import hf_hub_download
1437
- # True for case when locally already logged in with correct token, so don't have to set key
1438
- token = os.getenv('HUGGINGFACE_API_TOKEN', True)
1439
- path_to_zip_file = hf_hub_download('h2oai/db_dirs', db_dir, token=token, repo_type='dataset')
1440
- import zipfile
1441
- with zipfile.ZipFile(path_to_zip_file, 'r') as zip_ref:
1442
- zip_ref.extractall(dest)
1443
- return path_to_zip_file
1444
-
1445
-
1446
- # Note dir has space in some cases, while zip does not
1447
- some_db_zips = [['db_dir_DriverlessAI_docs.zip', 'db_dir_DriverlessAI docs', 'CC-BY-NC license'],
1448
- ['db_dir_UserData.zip', 'db_dir_UserData', 'CC-BY license for ArXiv'],
1449
- ['db_dir_github_h2oGPT.zip', 'db_dir_github h2oGPT', 'ApacheV2 license'],
1450
- ['db_dir_wiki.zip', 'db_dir_wiki', 'CC-BY-SA Wikipedia license'],
1451
- # ['db_dir_wiki_full.zip', 'db_dir_wiki_full.zip', '23GB, 05/04/2023 CC-BY-SA Wiki license'],
1452
- ]
1453
-
1454
- all_db_zips = some_db_zips + \
1455
- [['db_dir_wiki_full.zip', 'db_dir_wiki_full.zip', '23GB, 05/04/2023 CC-BY-SA Wiki license'],
1456
- ]
1457
-
1458
-
1459
- def get_some_dbs_from_hf(dest='.', db_zips=None):
1460
- if db_zips is None:
1461
- db_zips = some_db_zips
1462
- for db_dir, dir_expected, license1 in db_zips:
1463
- path_to_zip_file = get_db_from_hf(dest=dest, db_dir=db_dir)
1464
- assert os.path.isfile(path_to_zip_file), "Missing zip in %s" % path_to_zip_file
1465
- if dir_expected:
1466
- assert os.path.isdir(os.path.join(dest, dir_expected)), "Missing path for %s" % dir_expected
1467
- assert os.path.isdir(os.path.join(dest, dir_expected, 'index')), "Missing index in %s" % dir_expected
1468
-
1469
-
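# A minimal sketch for pulling the prebuilt Chroma databases listed in some_db_zips above
# from the h2oai/db_dirs dataset (licenses as noted).  Assumes gpt_langchain is importable
# and that HUGGINGFACE_API_TOKEN is set or a local Hugging Face login exists.
from gpt_langchain import get_some_dbs_from_hf

get_some_dbs_from_hf(dest=".")  # unpacks e.g. "db_dir_DriverlessAI docs/" and "db_dir_wiki/"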
1470
- if __name__ == '__main__':
1471
- pass
gpt_langchain.py ADDED
@@ -0,0 +1 @@
 
 
1
+ ../../gpt_langchain.py
gradio_runner.py DELETED
@@ -1,1741 +0,0 @@
1
- import copy
2
- import functools
3
- import inspect
4
- import json
5
- import os
6
- import random
7
- import sys
8
- import traceback
9
- import uuid
10
- import filelock
11
- import pandas as pd
12
- import requests
13
- import tabulate
14
-
15
- # This is a hack to prevent Gradio from phoning home when it gets imported
16
- os.environ['GRADIO_ANALYTICS_ENABLED'] = 'False'
17
-
18
-
19
- def my_get(url, **kwargs):
20
- print('Gradio HTTP request redirected to localhost :)', flush=True)
21
- kwargs.setdefault('allow_redirects', True)
22
- return requests.api.request('get', 'http://127.0.0.1/', **kwargs)
23
-
24
-
25
- original_get = requests.get
26
- requests.get = my_get
27
- import gradio as gr
28
-
29
- requests.get = original_get
30
-
31
- from gradio_themes import H2oTheme, SoftTheme, get_h2o_title, get_simple_title, get_dark_js
32
- from prompter import Prompter, \
33
- prompt_type_to_model_name, prompt_types_strings, inv_prompt_type_to_model_lower, generate_prompt, non_hf_types
34
- from utils import get_githash, flatten_list, zip_data, s3up, clear_torch_cache, get_torch_allocated, system_info_print, \
35
- ping, get_short_name, get_url, makedirs, get_kwargs
36
- from generate import get_model, languages_covered, evaluate, eval_func_param_names, score_qa, langchain_modes, \
37
- inputs_kwargs_list, get_cutoffs, scratch_base_dir
38
-
39
- from apscheduler.schedulers.background import BackgroundScheduler
40
-
41
-
42
- def go_gradio(**kwargs):
43
- allow_api = kwargs['allow_api']
44
- is_public = kwargs['is_public']
45
- is_hf = kwargs['is_hf']
46
- memory_restriction_level = kwargs['memory_restriction_level']
47
- n_gpus = kwargs['n_gpus']
48
- admin_pass = kwargs['admin_pass']
49
- model_state0 = kwargs['model_state0']
50
- score_model_state0 = kwargs['score_model_state0']
51
- dbs = kwargs['dbs']
52
- db_type = kwargs['db_type']
53
- visible_langchain_modes = kwargs['visible_langchain_modes']
54
- allow_upload_to_user_data = kwargs['allow_upload_to_user_data']
55
- allow_upload_to_my_data = kwargs['allow_upload_to_my_data']
56
- enable_sources_list = kwargs['enable_sources_list']
57
- enable_url_upload = kwargs['enable_url_upload']
58
- enable_text_upload = kwargs['enable_text_upload']
59
- use_openai_embedding = kwargs['use_openai_embedding']
60
- hf_embedding_model = kwargs['hf_embedding_model']
61
- enable_captions = kwargs['enable_captions']
62
- captions_model = kwargs['captions_model']
63
- enable_ocr = kwargs['enable_ocr']
64
- caption_loader = kwargs['caption_loader']
65
-
66
- # easy update of kwargs needed for evaluate() etc.
67
- queue = True
68
- allow_upload = allow_upload_to_user_data or allow_upload_to_my_data
69
- kwargs.update(locals())
70
-
71
- if 'mbart-' in kwargs['model_lower']:
72
- instruction_label_nochat = "Text to translate"
73
- else:
74
- instruction_label_nochat = "Instruction (Shift-Enter or push Submit to send message," \
75
- " use Enter for multiple input lines)"
76
- if kwargs['input_lines'] > 1:
77
- instruction_label = "You (Shift-Enter or push Submit to send message, use Enter for multiple input lines)"
78
- else:
79
- instruction_label = "You (Enter or push Submit to send message, shift-enter for more lines)"
80
-
81
- title = 'h2oGPT'
82
- if 'h2ogpt-research' in kwargs['base_model']:
83
- title += " [Research demonstration]"
84
- more_info = """For more information, visit our GitHub pages: [h2oGPT](https://github.com/h2oai/h2ogpt) and [H2O-LLMStudio](https://github.com/h2oai/h2o-llmstudio)<br>"""
85
- if is_public:
86
- more_info += """<iframe src="https://ghbtns.com/github-btn.html?user=h2oai&repo=h2ogpt&type=star&count=true&size=small" frameborder="0" scrolling="0" width="150" height="20" title="GitHub"></iframe>"""
87
- if kwargs['verbose']:
88
- description = f"""Model {kwargs['base_model']} Instruct dataset.
89
- For more information, visit our GitHub pages: [h2oGPT](https://github.com/h2oai/h2ogpt) and [H2O LLM Studio](https://github.com/h2oai/h2o-llmstudio).
90
- Command: {str(' '.join(sys.argv))}
91
- Hash: {get_githash()}
92
- """
93
- else:
94
- description = more_info
95
- description += "If this host is busy, try [12B](https://gpt.h2o.ai), [Falcon 40B](http://falcon.h2o.ai), [HF Spaces1 12B](https://huggingface.co/spaces/h2oai/h2ogpt-chatbot) or [HF Spaces2 12B](https://huggingface.co/spaces/h2oai/h2ogpt-chatbot2)<br>"
96
- description += """<p>By using h2oGPT, you accept our [Terms of Service](https://github.com/h2oai/h2ogpt/blob/main/docs/tos.md)</p>"""
97
- if is_hf:
98
- description += '''<a href="https://huggingface.co/spaces/h2oai/h2ogpt-chatbot?duplicate=true"><img src="https://bit.ly/3gLdBN6" style="white-space: nowrap" alt="Duplicate Space"></a>'''
99
-
100
- if kwargs['verbose']:
101
- task_info_md = f"""
102
- ### Task: {kwargs['task_info']}"""
103
- else:
104
- task_info_md = ''
105
-
106
- if kwargs['h2ocolors']:
107
- css_code = """footer {visibility: hidden;}
108
- body{background:linear-gradient(#f5f5f5,#e5e5e5);}
109
- body.dark{background:linear-gradient(#000000,#0d0d0d);}
110
- """
111
- else:
112
- css_code = """footer {visibility: hidden}"""
113
- css_code += """
114
- @import url('https://fonts.googleapis.com/css2?family=Source+Sans+Pro:wght@400;600&display=swap');
115
- body.dark{#warning {background-color: #555555};}
116
- #small_btn {
117
- margin: 0.6em 0em 0.55em 0;
118
- max-width: 20em;
119
- min-width: 5em !important;
120
- height: 5em;
121
- font-size: 14px !important
122
- }"""
123
-
124
- if kwargs['gradio_avoid_processing_markdown']:
125
- from gradio_client import utils as client_utils
126
- from gradio.components import Chatbot
127
-
128
- # gradio has an issue with taking too long to process input/output for markdown etc.
129
- # Avoid for now, allow raw html to render, good enough for chatbot.
130
- def _postprocess_chat_messages(self, chat_message: str):
131
- if chat_message is None:
132
- return None
133
- elif isinstance(chat_message, (tuple, list)):
134
- filepath = chat_message[0]
135
- mime_type = client_utils.get_mimetype(filepath)
136
- filepath = self.make_temp_copy_if_needed(filepath)
137
- return {
138
- "name": filepath,
139
- "mime_type": mime_type,
140
- "alt_text": chat_message[1] if len(chat_message) > 1 else None,
141
- "data": None, # These last two fields are filled in by the frontend
142
- "is_file": True,
143
- }
144
- elif isinstance(chat_message, str):
145
- return chat_message
146
- else:
147
- raise ValueError(f"Invalid message for Chatbot component: {chat_message}")
148
-
149
- Chatbot._postprocess_chat_messages = _postprocess_chat_messages
150
-
151
- if kwargs['gradio_offline_level'] >= 0:
152
- # avoid GoogleFont that pulls from internet
153
- if kwargs['gradio_offline_level'] == 1:
154
- # front end would still have to download fonts or have cached it at some point
155
- base_font = 'Source Sans Pro'
156
- else:
157
- base_font = 'Helvetica'
158
- theme_kwargs = dict(font=(base_font, 'ui-sans-serif', 'system-ui', 'sans-serif'),
159
- font_mono=('IBM Plex Mono', 'ui-monospace', 'Consolas', 'monospace'))
160
- else:
161
- theme_kwargs = dict()
162
-
163
- theme = H2oTheme(**theme_kwargs) if kwargs['h2ocolors'] else SoftTheme(**theme_kwargs)
164
- demo = gr.Blocks(theme=theme, css=css_code, title="h2oGPT", analytics_enabled=False)
165
- callback = gr.CSVLogger()
166
-
167
- model_options = flatten_list(list(prompt_type_to_model_name.values())) + kwargs['extra_model_options']
168
- if kwargs['base_model'].strip() not in model_options:
169
- model_options = [kwargs['base_model'].strip()] + model_options
170
- lora_options = kwargs['extra_lora_options']
171
- if kwargs['lora_weights'].strip() not in lora_options:
172
- lora_options = [kwargs['lora_weights'].strip()] + lora_options
173
- # always add in no lora case
174
- # add fake space so doesn't go away in gradio dropdown
175
- no_lora_str = no_model_str = '[None/Remove]'
176
- lora_options = [no_lora_str] + kwargs['extra_lora_options'] # FIXME: why double?
177
- # always add in no model case so can free memory
178
- # add fake space so doesn't go away in gradio dropdown
179
- model_options = [no_model_str] + model_options
180
-
181
- # transcribe, will be detranscribed before use by evaluate()
182
- if not kwargs['lora_weights'].strip():
183
- kwargs['lora_weights'] = no_lora_str
184
-
185
- if not kwargs['base_model'].strip():
186
- kwargs['base_model'] = no_model_str
187
-
188
- # transcribe for gradio
189
- kwargs['gpu_id'] = str(kwargs['gpu_id'])
190
-
191
- no_model_msg = 'h2oGPT [ !!! Please Load Model in Models Tab !!! ]'
192
- output_label0 = f'h2oGPT [Model: {kwargs.get("base_model")}]' if kwargs.get(
193
- 'base_model') else no_model_msg
194
- output_label0_model2 = no_model_msg
195
-
196
- with demo:
197
- # avoid actual model/tokenizer here or anything that would be bad to deepcopy
198
- # https://github.com/gradio-app/gradio/issues/3558
199
- model_state = gr.State(['model', 'tokenizer', kwargs['device'], kwargs['base_model']])
200
- model_state2 = gr.State([None, None, None, None])
201
- model_options_state = gr.State([model_options])
202
- lora_options_state = gr.State([lora_options])
203
- my_db_state = gr.State([None, None])
204
- chat_state = gr.State({})
205
- # make user default first and default choice, dedup
206
- docs_state00 = kwargs['document_choice'] + ['All', 'Only', 'None']
207
- docs_state0 = []
208
- [docs_state0.append(x) for x in docs_state00 if x not in docs_state0]
209
- docs_state = gr.State(docs_state0) # first is chosen as default
210
- gr.Markdown(f"""
211
- {get_h2o_title(title) if kwargs['h2ocolors'] else get_simple_title(title)}
212
-
213
- {description}
214
- {task_info_md}
215
- """)
216
- if is_hf:
217
- gr.HTML(
218
- )
219
-
220
- # go button visible only when a base model is preset and login_mode_if_model0 is enabled
221
- base_wanted = kwargs['base_model'] != no_model_str and kwargs['login_mode_if_model0']
222
- go_btn = gr.Button(value="ENTER", visible=base_wanted, variant="primary")
223
- normal_block = gr.Row(visible=not base_wanted)
224
- with normal_block:
225
- with gr.Tabs():
226
- with gr.Row():
227
- col_nochat = gr.Column(visible=not kwargs['chat'])
228
- with col_nochat: # FIXME: for model comparison, and check rest
229
- text_output_nochat = gr.Textbox(lines=5, label=output_label0).style(show_copy_button=True)
230
- instruction_nochat = gr.Textbox(
231
- lines=kwargs['input_lines'],
232
- label=instruction_label_nochat,
233
- placeholder=kwargs['placeholder_instruction'],
234
- )
235
- iinput_nochat = gr.Textbox(lines=4, label="Input context for Instruction",
236
- placeholder=kwargs['placeholder_input'])
237
- submit_nochat = gr.Button("Submit")
238
- flag_btn_nochat = gr.Button("Flag")
239
- if not kwargs['auto_score']:
240
- with gr.Column(visible=kwargs['score_model']):
241
- score_btn_nochat = gr.Button("Score last prompt & response")
242
- score_text_nochat = gr.Textbox("Response Score: NA", show_label=False)
243
- else:
244
- with gr.Column(visible=kwargs['score_model']):
245
- score_text_nochat = gr.Textbox("Response Score: NA", show_label=False)
246
- col_chat = gr.Column(visible=kwargs['chat'])
247
- with col_chat:
248
- with gr.Row():
249
- text_output = gr.Chatbot(label=output_label0).style(height=kwargs['height'] or 400)
250
- text_output2 = gr.Chatbot(label=output_label0_model2, visible=False).style(
251
- height=kwargs['height'] or 400)
252
- with gr.Row():
253
- with gr.Column(scale=50):
254
- instruction = gr.Textbox(
255
- lines=kwargs['input_lines'],
256
- label=instruction_label,
257
- placeholder=kwargs['placeholder_instruction'],
258
- )
259
- with gr.Row():
260
- submit = gr.Button(value='Submit').style(full_width=False, size='sm')
261
- stop_btn = gr.Button(value="Stop").style(full_width=False, size='sm')
262
- with gr.Row():
263
- clear = gr.Button("Save Chat / New Chat")
264
- flag_btn = gr.Button("Flag")
265
- if not kwargs['auto_score']: # FIXME: For checkbox model2
266
- with gr.Column(visible=kwargs['score_model']):
267
- with gr.Row():
268
- score_btn = gr.Button("Score last prompt & response").style(
269
- full_width=False, size='sm')
270
- score_text = gr.Textbox("Response Score: NA", show_label=False)
271
- score_res2 = gr.Row(visible=False)
272
- with score_res2:
273
- score_btn2 = gr.Button("Score last prompt & response 2").style(
274
- full_width=False, size='sm')
275
- score_text2 = gr.Textbox("Response Score2: NA", show_label=False)
276
- else:
277
- with gr.Column(visible=kwargs['score_model']):
278
- score_text = gr.Textbox("Response Score: NA", show_label=False)
279
- score_text2 = gr.Textbox("Response Score2: NA", show_label=False, visible=False)
280
- retry = gr.Button("Regenerate")
281
- undo = gr.Button("Undo")
282
- with gr.TabItem("Chat"):
283
- with gr.Row():
284
- if 'mbart-' in kwargs['model_lower']:
285
- src_lang = gr.Dropdown(list(languages_covered().keys()),
286
- value=kwargs['src_lang'],
287
- label="Input Language")
288
- tgt_lang = gr.Dropdown(list(languages_covered().keys()),
289
- value=kwargs['tgt_lang'],
290
- label="Output Language")
291
- radio_chats = gr.Radio(value=None, label="Saved Chats", visible=True, interactive=True,
292
- type='value')
293
- with gr.Row():
294
- clear_chat_btn = gr.Button(value="Clear Chat", visible=True).style(size='sm')
295
- export_chats_btn = gr.Button(value="Export Chats to Download").style(size='sm')
296
- remove_chat_btn = gr.Button(value="Remove Selected Chat", visible=True).style(size='sm')
297
- add_to_chats_btn = gr.Button("Import Chats from Upload").style(size='sm')
298
- with gr.Row():
299
- chats_file = gr.File(interactive=False, label="Download Exported Chats")
300
- chatsup_output = gr.File(label="Upload Chat File(s)",
301
- file_types=['.json'],
302
- file_count='multiple',
303
- elem_id="warning", elem_classes="feedback")
304
- with gr.TabItem("Data Source"):
305
- langchain_readme = get_url('https://github.com/h2oai/h2ogpt/blob/main/docs/README_LangChain.md',
306
- from_str=True)
307
- gr.HTML(value=f"""LangChain Support Disabled<p>
308
- Run:<p>
309
- <code>
310
- python generate.py --langchain_mode=MyData
311
- </code>
312
- <p>
313
- For more options see: {langchain_readme}""",
314
- visible=kwargs['langchain_mode'] == 'Disabled', interactive=False)
315
- data_row1 = gr.Row(visible=kwargs['langchain_mode'] != 'Disabled')
316
- with data_row1:
317
- if is_hf:
318
-                            # don't show 'wiki' since it's usually only useful for internal testing at the moment
319
- no_show_modes = ['Disabled', 'wiki']
320
- else:
321
- no_show_modes = ['Disabled']
322
- allowed_modes = visible_langchain_modes.copy()
323
- allowed_modes = [x for x in allowed_modes if x in dbs]
324
- allowed_modes += ['ChatLLM', 'LLM']
325
- if allow_upload_to_my_data and 'MyData' not in allowed_modes:
326
- allowed_modes += ['MyData']
327
- if allow_upload_to_user_data and 'UserData' not in allowed_modes:
328
- allowed_modes += ['UserData']
329
- langchain_mode = gr.Radio(
330
- [x for x in langchain_modes if x in allowed_modes and x not in no_show_modes],
331
- value=kwargs['langchain_mode'],
332
- label="Data Collection of Sources",
333
- visible=kwargs['langchain_mode'] != 'Disabled')
334
- data_row2 = gr.Row(visible=kwargs['langchain_mode'] != 'Disabled')
335
- with data_row2:
336
- with gr.Column(scale=50):
337
- document_choice = gr.Dropdown(docs_state.value,
338
- label="Choose Subset of Doc(s) in Collection [click get sources to update]",
339
- value=docs_state.value[0],
340
- interactive=True,
341
- multiselect=True,
342
- )
343
- with gr.Row(visible=kwargs['langchain_mode'] != 'Disabled' and enable_sources_list):
344
- get_sources_btn = gr.Button(value="Get Sources",
345
- ).style(full_width=False, size='sm')
346
- show_sources_btn = gr.Button(value="Show Sources",
347
- ).style(full_width=False, size='sm')
348
- refresh_sources_btn = gr.Button(value="Refresh Sources",
349
- ).style(full_width=False, size='sm')
350
-
351
- # import control
352
- if kwargs['langchain_mode'] != 'Disabled':
353
- from gpt_langchain import file_types, have_arxiv
354
- else:
355
- have_arxiv = False
356
- file_types = []
357
-
358
- upload_row = gr.Row(visible=kwargs['langchain_mode'] != 'Disabled' and allow_upload).style(
359
- equal_height=False)
360
- with upload_row:
361
- with gr.Column():
362
- file_types_str = '[' + ' '.join(file_types) + ']'
363
- fileup_output = gr.File(label=f'Upload {file_types_str}',
364
- file_types=file_types,
365
- file_count="multiple",
366
- elem_id="warning", elem_classes="feedback")
367
- with gr.Row():
368
- add_to_shared_db_btn = gr.Button("Add File(s) to UserData",
369
- visible=allow_upload_to_user_data, elem_id='small_btn')
370
- add_to_my_db_btn = gr.Button("Add File(s) to Scratch MyData",
371
- visible=allow_upload_to_my_data,
372
- elem_id='small_btn' if allow_upload_to_user_data else None,
373
- ).style(
374
- size='sm' if not allow_upload_to_user_data else None)
375
- with gr.Column(
376
- visible=kwargs['langchain_mode'] != 'Disabled' and allow_upload and enable_url_upload):
377
- url_label = 'URL (http/https) or ArXiv:' if have_arxiv else 'URL (http/https)'
378
- url_text = gr.Textbox(label=url_label, interactive=True)
379
- with gr.Row():
380
- url_user_btn = gr.Button(value='Add URL content to Shared UserData',
381
- visible=allow_upload_to_user_data, elem_id='small_btn')
382
- url_my_btn = gr.Button(value='Add URL content to Scratch MyData',
383
- visible=allow_upload_to_my_data,
384
- elem_id='small_btn' if allow_upload_to_user_data else None,
385
- ).style(size='sm' if not allow_upload_to_user_data else None)
386
- with gr.Column(
387
- visible=kwargs['langchain_mode'] != 'Disabled' and allow_upload and enable_text_upload):
388
- user_text_text = gr.Textbox(label='Paste Text [Shift-Enter more lines]', interactive=True)
389
- with gr.Row():
390
- user_text_user_btn = gr.Button(value='Add Text to Shared UserData',
391
- visible=allow_upload_to_user_data,
392
- elem_id='small_btn')
393
- user_text_my_btn = gr.Button(value='Add Text to Scratch MyData',
394
- visible=allow_upload_to_my_data,
395
- elem_id='small_btn' if allow_upload_to_user_data else None,
396
- ).style(
397
- size='sm' if not allow_upload_to_user_data else None)
398
- with gr.Column(visible=False):
399
- # WIP:
400
- with gr.Row(visible=False).style(equal_height=False):
401
- github_textbox = gr.Textbox(label="Github URL")
402
- with gr.Row(visible=True):
403
- github_shared_btn = gr.Button(value="Add Github to Shared UserData",
404
- visible=allow_upload_to_user_data,
405
- elem_id='small_btn')
406
- github_my_btn = gr.Button(value="Add Github to Scratch MyData",
407
- visible=allow_upload_to_my_data, elem_id='small_btn')
408
- sources_row3 = gr.Row(visible=kwargs['langchain_mode'] != 'Disabled' and enable_sources_list).style(
409
- equal_height=False)
410
- with sources_row3:
411
- with gr.Column(scale=1):
412
- file_source = gr.File(interactive=False,
413
- label="Download File w/Sources [click get sources to make file]")
414
- with gr.Column(scale=2):
415
- pass
416
- sources_row = gr.Row(visible=kwargs['langchain_mode'] != 'Disabled' and enable_sources_list).style(
417
- equal_height=False)
418
- with sources_row:
419
- sources_text = gr.HTML(label='Sources Added', interactive=False)
420
-
421
- with gr.TabItem("Expert"):
422
- with gr.Row():
423
- with gr.Column():
424
- stream_output = gr.components.Checkbox(label="Stream output",
425
- value=kwargs['stream_output'])
426
- prompt_type = gr.Dropdown(prompt_types_strings,
427
- value=kwargs['prompt_type'], label="Prompt Type",
428
- visible=not is_public)
429
- prompt_type2 = gr.Dropdown(prompt_types_strings,
430
- value=kwargs['prompt_type'], label="Prompt Type Model 2",
431
- visible=not is_public and False)
432
- do_sample = gr.Checkbox(label="Sample",
433
- info="Enable sampler, required for use of temperature, top_p, top_k",
434
- value=kwargs['do_sample'])
435
- temperature = gr.Slider(minimum=0.01, maximum=3,
436
- value=kwargs['temperature'],
437
- label="Temperature",
438
- info="Lower is deterministic (but may lead to repeats), Higher more creative (but may lead to hallucinations)")
439
- top_p = gr.Slider(minimum=0, maximum=1,
440
- value=kwargs['top_p'], label="Top p",
441
- info="Cumulative probability of tokens to sample from")
442
- top_k = gr.Slider(
443
- minimum=0, maximum=100, step=1,
444
- value=kwargs['top_k'], label="Top k",
445
- info='Num. tokens to sample from'
446
- )
447
- # FIXME: https://github.com/h2oai/h2ogpt/issues/106
448
- if os.getenv('TESTINGFAIL'):
449
- max_beams = 8 if not (memory_restriction_level or is_public) else 1
450
- else:
451
- max_beams = 1
452
- num_beams = gr.Slider(minimum=1, maximum=max_beams, step=1,
453
- value=min(max_beams, kwargs['num_beams']), label="Beams",
454
- info="Number of searches for optimal overall probability. "
455
- "Uses more GPU memory/compute")
456
- # FIXME: 2048 should be tokenizer.model_max_length, but may not even have model yet
457
- if kwargs['max_new_tokens']:
458
- max_max_new_tokens = kwargs['max_new_tokens']
459
- elif memory_restriction_level == 1:
460
- max_max_new_tokens = 768
461
- elif memory_restriction_level == 2:
462
- max_max_new_tokens = 512
463
- elif memory_restriction_level >= 3:
464
- max_max_new_tokens = 256
465
- else:
466
- max_max_new_tokens = 2048
467
- max_new_tokens = gr.Slider(
468
- minimum=1, maximum=max_max_new_tokens, step=1,
469
- value=min(max_max_new_tokens, kwargs['max_new_tokens']), label="Max output length",
470
- )
471
- min_new_tokens = gr.Slider(
472
- minimum=0, maximum=max_max_new_tokens, step=1,
473
- value=min(max_max_new_tokens, kwargs['min_new_tokens']), label="Min output length",
474
- )
475
- early_stopping = gr.Checkbox(label="EarlyStopping", info="Stop early in beam search",
476
- value=kwargs['early_stopping'])
477
- max_max_time = 60 * 5 if not is_public else 60 * 2
478
- if is_hf:
479
- max_max_time = min(max_max_time, 60 * 1)
480
- max_time = gr.Slider(minimum=0, maximum=max_max_time, step=1,
481
- value=min(max_max_time, kwargs['max_time']), label="Max. time",
482
- info="Max. time to search optimal output.")
483
- repetition_penalty = gr.Slider(minimum=0.01, maximum=3.0,
484
- value=kwargs['repetition_penalty'],
485
- label="Repetition Penalty")
486
- num_return_sequences = gr.Slider(minimum=1, maximum=10, step=1,
487
- value=kwargs['num_return_sequences'],
488
- label="Number Returns", info="Must be <= num_beams",
489
- visible=not is_public)
490
- iinput = gr.Textbox(lines=4, label="Input",
491
- placeholder=kwargs['placeholder_input'],
492
- visible=not is_public)
493
- context = gr.Textbox(lines=3, label="System Pre-Context",
494
- info="Directly pre-appended without prompt processing",
495
- visible=not is_public)
496
- chat = gr.components.Checkbox(label="Chat mode", value=kwargs['chat'],
497
- visible=not is_public)
498
- count_chat_tokens_btn = gr.Button(value="Count Chat Tokens", visible=not is_public)
499
- chat_token_count = gr.Textbox(label="Chat Token Count", value=None,
500
- visible=not is_public, interactive=False)
501
- top_k_docs = gr.Slider(minimum=0, maximum=20, step=1,
502
- value=kwargs['top_k_docs'],
503
- label="Number of document chunks",
504
- info="For LangChain",
505
- visible=not is_public)
506
-
507
- with gr.TabItem("Models"):
508
- load_msg = "Load-Unload Model/LORA [unload works if did not use --base_model]" if not is_public \
509
- else "LOAD-UNLOAD DISABLED FOR HOSTED DEMO"
510
- load_msg2 = "Load-Unload Model/LORA 2 [unload works if did not use --base_model]" if not is_public \
511
- else "LOAD-UNLOAD DISABLED FOR HOSTED DEMO 2"
512
- compare_checkbox = gr.components.Checkbox(label="Compare Mode",
513
- value=False, visible=not is_public)
514
- with gr.Row():
515
- n_gpus_list = [str(x) for x in list(range(-1, n_gpus))]
516
- with gr.Column():
517
- with gr.Row():
518
- with gr.Column(scale=50):
519
- model_choice = gr.Dropdown(model_options_state.value[0], label="Choose Model",
520
- value=kwargs['base_model'])
521
- lora_choice = gr.Dropdown(lora_options_state.value[0], label="Choose LORA",
522
- value=kwargs['lora_weights'], visible=kwargs['show_lora'])
523
- with gr.Column(scale=1):
524
- load_model_button = gr.Button(load_msg).style(full_width=False, size='sm')
525
- model_load8bit_checkbox = gr.components.Checkbox(
526
- label="Load 8-bit [requires support]",
527
- value=kwargs['load_8bit'])
528
- model_infer_devices_checkbox = gr.components.Checkbox(
529
- label="Choose Devices [If not Checked, use all GPUs]",
530
- value=kwargs['infer_devices'])
531
- model_gpu = gr.Dropdown(n_gpus_list,
532
- label="GPU ID [-1 = all GPUs, if Choose is enabled]",
533
- value=kwargs['gpu_id'])
534
- model_used = gr.Textbox(label="Current Model", value=kwargs['base_model'],
535
- interactive=False)
536
- lora_used = gr.Textbox(label="Current LORA", value=kwargs['lora_weights'],
537
- visible=kwargs['show_lora'], interactive=False)
538
- col_model2 = gr.Column(visible=False)
539
- with col_model2:
540
- with gr.Row():
541
- with gr.Column(scale=50):
542
- model_choice2 = gr.Dropdown(model_options_state.value[0], label="Choose Model 2",
543
- value=no_model_str)
544
- lora_choice2 = gr.Dropdown(lora_options_state.value[0], label="Choose LORA 2",
545
- value=no_lora_str,
546
- visible=kwargs['show_lora'])
547
- with gr.Column(scale=1):
548
- load_model_button2 = gr.Button(load_msg2).style(full_width=False, size='sm')
549
- model_load8bit_checkbox2 = gr.components.Checkbox(
550
- label="Load 8-bit 2 [requires support]",
551
- value=kwargs['load_8bit'])
552
- model_infer_devices_checkbox2 = gr.components.Checkbox(
553
- label="Choose Devices 2 [If not Checked, use all GPUs]",
554
- value=kwargs[
555
- 'infer_devices'])
556
- model_gpu2 = gr.Dropdown(n_gpus_list,
557
- label="GPU ID 2 [-1 = all GPUs, if choose is enabled]",
558
- value=kwargs['gpu_id'])
559
- # no model/lora loaded ever in model2 by default
560
- model_used2 = gr.Textbox(label="Current Model 2", value=no_model_str)
561
- lora_used2 = gr.Textbox(label="Current LORA 2", value=no_lora_str,
562
- visible=kwargs['show_lora'])
563
- with gr.Row():
564
- with gr.Column(scale=50):
565
- new_model = gr.Textbox(label="New Model HF name/path")
566
- with gr.Row():
567
- add_model_button = gr.Button("Add new model name").style(full_width=False, size='sm')
568
- with gr.Column(scale=50):
569
- new_lora = gr.Textbox(label="New LORA HF name/path", visible=kwargs['show_lora'])
570
- with gr.Row():
571
- add_lora_button = gr.Button("Add new LORA name", visible=kwargs['show_lora']).style(
572
- full_width=False, size='sm')
573
- with gr.TabItem("System"):
574
- admin_row = gr.Row()
575
- with admin_row:
576
- admin_pass_textbox = gr.Textbox(label="Admin Password", type='password', visible=is_public)
577
- admin_btn = gr.Button(value="Admin Access", visible=is_public)
578
- system_row = gr.Row(visible=not is_public)
579
- with system_row:
580
- with gr.Column():
581
- with gr.Row():
582
- system_btn = gr.Button(value='Get System Info')
583
- system_text = gr.Textbox(label='System Info', interactive=False).style(
584
- show_copy_button=True)
585
-
586
- with gr.Row():
587
- zip_btn = gr.Button("Zip")
588
- zip_text = gr.Textbox(label="Zip file name", interactive=False)
589
- file_output = gr.File(interactive=False, label="Zip file to Download")
590
- with gr.Row():
591
- s3up_btn = gr.Button("S3UP")
592
- s3up_text = gr.Textbox(label='S3UP result', interactive=False)
593
- with gr.TabItem("Disclaimers"):
594
- description = ""
595
- description += """<p><b> DISCLAIMERS: </b><ul><i><li>The model was trained on The Pile and other data, which may contain objectionable content. Use at own risk.</i></li>"""
596
- if kwargs['load_8bit']:
597
- description += """<i><li> Model is loaded in 8-bit and has other restrictions on this host. UX can be worse than non-hosted version.</i></li>"""
598
- description += """<i><li>Conversations may be used to improve h2oGPT. Do not share sensitive information.</i></li>"""
599
- if 'h2ogpt-research' in kwargs['base_model']:
600
- description += """<i><li>Research demonstration only, not used for commercial purposes.</i></li>"""
601
- description += """<i><li>By using h2oGPT, you accept our <a href="https://github.com/h2oai/h2ogpt/blob/main/docs/tos.md">Terms of Service</a></i></li></ul></p>"""
602
- gr.Markdown(value=description, show_label=False, interactive=False)
603
-
604
- # Get flagged data
605
- zip_data1 = functools.partial(zip_data, root_dirs=['flagged_data_points', kwargs['save_dir']])
606
- zip_btn.click(zip_data1, inputs=None, outputs=[file_output, zip_text], queue=False,
607
- api_name='zip_data' if allow_api else None)
608
- s3up_btn.click(s3up, inputs=zip_text, outputs=s3up_text, queue=False,
609
- api_name='s3up_data' if allow_api else None)
610
-
611
- def make_add_visible(x):
612
- return gr.update(visible=x is not None)
613
-
614
- def clear_file_list():
615
- return None
616
-
617
- def make_invisible():
618
- return gr.update(visible=False)
619
-
620
- def make_visible():
621
- return gr.update(visible=True)
622
-
623
- # Add to UserData
624
- update_user_db_func = functools.partial(update_user_db, dbs=dbs, db_type=db_type, langchain_mode='UserData',
625
- use_openai_embedding=use_openai_embedding,
626
- hf_embedding_model=hf_embedding_model,
627
- enable_captions=enable_captions,
628
- captions_model=captions_model,
629
- enable_ocr=enable_ocr,
630
- caption_loader=caption_loader,
631
- )
632
-
633
- # note for update_user_db_func output is ignored for db
634
- add_to_shared_db_btn.click(update_user_db_func,
635
- inputs=[fileup_output, my_db_state, add_to_shared_db_btn, add_to_my_db_btn],
636
- outputs=[add_to_shared_db_btn, add_to_my_db_btn, sources_text], queue=queue,
637
- api_name='add_to_shared' if allow_api else None) \
638
- .then(clear_file_list, outputs=fileup_output, queue=queue)
639
-
640
- # .then(make_invisible, outputs=add_to_shared_db_btn, queue=queue)
641
- # .then(make_visible, outputs=upload_button, queue=queue)
642
-
643
- def clear_textbox():
644
- return gr.Textbox.update(value='')
645
-
646
- update_user_db_url_func = functools.partial(update_user_db_func, is_url=True)
647
- url_user_btn.click(update_user_db_url_func,
648
- inputs=[url_text, my_db_state, add_to_shared_db_btn, add_to_my_db_btn],
649
- outputs=[add_to_shared_db_btn, add_to_my_db_btn, sources_text], queue=queue,
650
- api_name='add_url_to_shared' if allow_api else None) \
651
- .then(clear_textbox, outputs=url_text, queue=queue)
652
-
653
- update_user_db_txt_func = functools.partial(update_user_db_func, is_txt=True)
654
- user_text_user_btn.click(update_user_db_txt_func,
655
- inputs=[user_text_text, my_db_state, add_to_shared_db_btn, add_to_my_db_btn],
656
- outputs=[add_to_shared_db_btn, add_to_my_db_btn, sources_text], queue=queue,
657
- api_name='add_text_to_shared' if allow_api else None) \
658
- .then(clear_textbox, outputs=user_text_text, queue=queue)
659
-
660
- # Add to MyData
661
- update_my_db_func = functools.partial(update_user_db, dbs=dbs, db_type=db_type, langchain_mode='MyData',
662
- use_openai_embedding=use_openai_embedding,
663
- hf_embedding_model=hf_embedding_model,
664
- enable_captions=enable_captions,
665
- captions_model=captions_model,
666
- enable_ocr=enable_ocr,
667
- caption_loader=caption_loader,
668
- )
669
-
670
- add_to_my_db_btn.click(update_my_db_func,
671
- inputs=[fileup_output, my_db_state, add_to_shared_db_btn, add_to_my_db_btn],
672
- outputs=[my_db_state, add_to_shared_db_btn, add_to_my_db_btn, sources_text], queue=queue,
673
- api_name='add_to_my' if allow_api else None) \
674
- .then(clear_file_list, outputs=fileup_output, queue=queue)
675
- # .then(make_invisible, outputs=add_to_shared_db_btn, queue=queue)
676
- # .then(make_visible, outputs=upload_button, queue=queue)
677
-
678
- update_my_db_url_func = functools.partial(update_my_db_func, is_url=True)
679
- url_my_btn.click(update_my_db_url_func,
680
- inputs=[url_text, my_db_state, add_to_shared_db_btn, add_to_my_db_btn],
681
- outputs=[my_db_state, add_to_shared_db_btn, add_to_my_db_btn, sources_text], queue=queue,
682
- api_name='add_url_to_my' if allow_api else None) \
683
- .then(clear_textbox, outputs=url_text, queue=queue)
684
-
685
- update_my_db_txt_func = functools.partial(update_my_db_func, is_txt=True)
686
- user_text_my_btn.click(update_my_db_txt_func,
687
- inputs=[user_text_text, my_db_state, add_to_shared_db_btn, add_to_my_db_btn],
688
- outputs=[my_db_state, add_to_shared_db_btn, add_to_my_db_btn, sources_text], queue=queue,
689
- api_name='add_txt_to_my' if allow_api else None) \
690
- .then(clear_textbox, outputs=user_text_text, queue=queue)
691
-
692
- get_sources1 = functools.partial(get_sources, dbs=dbs, docs_state0=docs_state0)
693
-
694
- # if change collection source, must clear doc selections from it to avoid inconsistency
695
- def clear_doc_choice():
696
- return gr.Dropdown.update(choices=docs_state0, value=[docs_state0[0]])
697
-
698
- langchain_mode.change(clear_doc_choice, inputs=None, outputs=document_choice)
699
-
700
- def update_dropdown(x):
701
- return gr.Dropdown.update(choices=x, value=[docs_state0[0]])
702
-
703
- get_sources_btn.click(get_sources1, inputs=[my_db_state, langchain_mode], outputs=[file_source, docs_state],
704
- queue=queue,
705
- api_name='get_sources' if allow_api else None) \
706
- .then(fn=update_dropdown, inputs=docs_state, outputs=document_choice)
707
-            # always show the button, else it would only appear after an add. Could fold this into get_sources above for download/dropdown, but that is probably too much
708
- show_sources1 = functools.partial(get_source_files_given_langchain_mode, dbs=dbs)
709
- show_sources_btn.click(fn=show_sources1, inputs=[my_db_state, langchain_mode], outputs=sources_text,
710
- api_name='show_sources' if allow_api else None)
711
-
712
- # Get inputs to evaluate() and make_db()
713
- # don't deepcopy, can contain model itself
714
- all_kwargs = kwargs.copy()
715
- all_kwargs.update(locals())
716
-
717
- refresh_sources1 = functools.partial(update_and_get_source_files_given_langchain_mode,
718
- **get_kwargs(update_and_get_source_files_given_langchain_mode,
719
- exclude_names=['db1', 'langchain_mode'],
720
- **all_kwargs))
721
- refresh_sources_btn.click(fn=refresh_sources1, inputs=[my_db_state, langchain_mode], outputs=sources_text,
722
- api_name='refresh_sources' if allow_api else None)
723
-
724
- def check_admin_pass(x):
725
- return gr.update(visible=x == admin_pass)
726
-
727
- def close_admin(x):
728
- return gr.update(visible=not (x == admin_pass))
729
-
730
- admin_btn.click(check_admin_pass, inputs=admin_pass_textbox, outputs=system_row, queue=False) \
731
- .then(close_admin, inputs=admin_pass_textbox, outputs=admin_row, queue=False)
732
-
733
- inputs_list = get_inputs_list(all_kwargs, kwargs['model_lower'])
734
- from functools import partial
735
- kwargs_evaluate = {k: v for k, v in all_kwargs.items() if k in inputs_kwargs_list}
736
- # ensure present
737
- for k in inputs_kwargs_list:
738
- assert k in kwargs_evaluate, "Missing %s" % k
739
- fun = partial(evaluate,
740
- **kwargs_evaluate)
741
- fun2 = partial(evaluate,
742
- **kwargs_evaluate)
743
-
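Editor's note: the two partial() calls above pre-bind the fixed keyword arguments so the Gradio event handlers only have to pass the per-request inputs. A hedged, self-contained sketch of that pattern (fake_evaluate is a made-up stand-in for evaluate()):

    from functools import partial

    # hypothetical stand-in for evaluate(); only 'instruction' comes from the UI,
    # the rest is configuration fixed once at startup
    def fake_evaluate(instruction, temperature=0.7, max_new_tokens=256):
        return f"gen(instruction={instruction!r}, T={temperature}, max={max_new_tokens})"

    kwargs_fixed = dict(temperature=0.2, max_new_tokens=128)
    fun = partial(fake_evaluate, **kwargs_fixed)

    print(fun("Hello"))  # only the runtime input is supplied at call time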
744
- dark_mode_btn = gr.Button("Dark Mode", variant="primary").style(
745
- size="sm",
746
- )
747
- # FIXME: Could add exceptions for non-chat but still streaming
748
- exception_text = gr.Textbox(value="", visible=kwargs['chat'], label='Chat Exceptions', interactive=False)
749
- dark_mode_btn.click(
750
- None,
751
- None,
752
- None,
753
- _js=get_dark_js(),
754
- api_name="dark" if allow_api else None,
755
- queue=False,
756
- )
757
-
758
- # Control chat and non-chat blocks, which can be independently used by chat checkbox swap
759
- def col_nochat_fun(x):
760
- return gr.Column.update(visible=not x)
761
-
762
- def col_chat_fun(x):
763
- return gr.Column.update(visible=x)
764
-
765
- def context_fun(x):
766
- return gr.Textbox.update(visible=not x)
767
-
768
- chat.select(col_nochat_fun, chat, col_nochat, api_name="chat_checkbox" if allow_api else None) \
769
- .then(col_chat_fun, chat, col_chat) \
770
- .then(context_fun, chat, context) \
771
- .then(col_chat_fun, chat, exception_text)
772
-
773
- # examples after submit or any other buttons for chat or no chat
774
- if kwargs['examples'] is not None and kwargs['show_examples']:
775
- gr.Examples(examples=kwargs['examples'], inputs=inputs_list)
776
-
777
- # Score
778
- def score_last_response(*args, nochat=False, model2=False):
779
- """ Similar to user() """
780
- args_list = list(args)
781
-
782
- if memory_restriction_level > 0:
783
- max_length_tokenize = 768 - 256 if memory_restriction_level <= 2 else 512 - 256
784
- else:
785
- max_length_tokenize = 2048 - 256
786
- cutoff_len = max_length_tokenize * 4 # restrict deberta related to max for LLM
787
- smodel = score_model_state0[0]
788
- stokenizer = score_model_state0[1]
789
- sdevice = score_model_state0[2]
790
- if not nochat:
791
- history = args_list[-1]
792
- if history is None:
793
- if not model2:
794
- # maybe only doing first model, no need to complain
795
- print("Bad history in scoring last response, fix for now", flush=True)
796
- history = []
797
- if smodel is not None and \
798
- stokenizer is not None and \
799
- sdevice is not None and \
800
- history is not None and len(history) > 0 and \
801
- history[-1] is not None and \
802
- len(history[-1]) >= 2:
803
- os.environ['TOKENIZERS_PARALLELISM'] = 'false'
804
-
805
- question = history[-1][0]
806
-
807
- answer = history[-1][1]
808
- else:
809
- return 'Response Score: NA'
810
- else:
811
- answer = args_list[-1]
812
- instruction_nochat_arg_id = eval_func_param_names.index('instruction_nochat')
813
- question = args_list[instruction_nochat_arg_id]
814
-
815
- if question is None:
816
- return 'Response Score: Bad Question'
817
- if answer is None:
818
- return 'Response Score: Bad Answer'
819
- score = score_qa(smodel, stokenizer, max_length_tokenize, question, answer, cutoff_len)
820
- if isinstance(score, str):
821
- return 'Response Score: NA'
822
- return 'Response Score: {:.1%}'.format(score)
823
-
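Editor's note: cutoff_len above applies a rough characters-per-token heuristic (about 4 characters per token) to bound the text handed to the reward-model tokenizer. A minimal sketch of that idea; the 4:1 ratio and the keep-the-tail choice are assumptions for illustration, not guarantees of the original behavior:

    # assumed ~4 characters per token; real tokenizers vary by language and vocabulary
    CHARS_PER_TOKEN = 4

    def truncate_for_scoring(question, answer, max_length_tokenize=2048 - 256):
        cutoff_len = max_length_tokenize * CHARS_PER_TOKEN
        # one reasonable choice: keep the tail, since the most recent text usually matters most
        return question[-cutoff_len:], answer[-cutoff_len:]

    q, a = truncate_for_scoring("why is the sky blue? " * 1000, "Rayleigh scattering. " * 1000)
    print(len(q), len(a))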
824
- def noop_score_last_response(*args, **kwargs):
825
- return "Response Score: Disabled"
826
-
827
- if kwargs['score_model']:
828
- score_fun = score_last_response
829
- else:
830
- score_fun = noop_score_last_response
831
-
832
- score_args = dict(fn=score_fun,
833
- inputs=inputs_list + [text_output],
834
- outputs=[score_text],
835
- )
836
- score_args2 = dict(fn=partial(score_fun, model2=True),
837
- inputs=inputs_list + [text_output2],
838
- outputs=[score_text2],
839
- )
840
-
841
- score_args_nochat = dict(fn=partial(score_fun, nochat=True),
842
- inputs=inputs_list + [text_output_nochat],
843
- outputs=[score_text_nochat],
844
- )
845
- if not kwargs['auto_score']:
846
- score_event = score_btn.click(**score_args, queue=queue, api_name='score' if allow_api else None) \
847
- .then(**score_args2, queue=queue, api_name='score2' if allow_api else None)
848
- score_event_nochat = score_btn_nochat.click(**score_args_nochat, queue=queue,
849
- api_name='score_nochat' if allow_api else None)
850
-
851
- def user(*args, undo=False, sanitize_user_prompt=True, model2=False):
852
- """
853
- User that fills history for bot
854
- :param args:
855
- :param undo:
856
- :param sanitize_user_prompt:
857
- :param model2:
858
- :return:
859
- """
860
- args_list = list(args)
861
- user_message = args_list[eval_func_param_names.index('instruction')] # chat only
862
- input1 = args_list[eval_func_param_names.index('iinput')] # chat only
863
- context1 = args_list[eval_func_param_names.index('context')]
864
- prompt_type1 = args_list[eval_func_param_names.index('prompt_type')]
865
- chat1 = args_list[eval_func_param_names.index('chat')]
866
- stream_output1 = args_list[eval_func_param_names.index('stream_output')]
867
- if input1 and not user_message.endswith(':'):
868
- user_message1 = user_message + ":" + input1
869
- elif input1:
870
- user_message1 = user_message + input1
871
- else:
872
- user_message1 = user_message
873
- if sanitize_user_prompt:
874
- from better_profanity import profanity
875
- user_message1 = profanity.censor(user_message1)
876
-            # FIXME: WIP to use desired separator when user enters nothing
877
- prompter = Prompter(prompt_type1, debug=kwargs['debug'], chat=chat1, stream_output=stream_output1)
878
- if user_message1 in ['']:
879
- # e.g. when user just hits enter in textbox,
880
- # else will have <human>: <bot>: on single line, which seems to be "ok" for LLM but not usual
881
- user_message1 = '\n'
882
- # ensure good visually, else markdown ignores multiple \n
883
- user_message1 = user_message1.replace('\n', '<br>')
884
-
885
- history = args_list[-1]
886
- if undo and history:
887
- history.pop()
888
- args_list = args_list[:-1] # FYI, even if unused currently
889
- if history is None:
890
- if not model2:
891
-                    # no need to complain so often unless it's model 1
892
- print("Bad history, fix for now", flush=True)
893
- history = []
894
- # ensure elements not mixed across models as output,
895
- # even if input is currently same source
896
- history = history.copy()
897
- if undo:
898
- return history
899
- else:
900
- # FIXME: compare, same history for now
901
- return history + [[user_message1, None]]
902
-
903
- def history_to_context(history, langchain_mode1, prompt_type1, chat1):
904
- # ensure output will be unique to models
905
- # FIXME: hard-coded 2048 implicitly passed:
906
- _, _, _, max_prompt_length = get_cutoffs(memory_restriction_level, for_context=True)
907
- history = copy.deepcopy(history)
908
-
909
- context1 = ''
910
- if max_prompt_length is not None and langchain_mode1 not in ['LLM']:
911
- context1 = ''
912
- # - 1 below because current instruction already in history from user()
913
- for histi in range(0, len(history) - 1):
914
- data_point = dict(instruction=history[histi][0], input='', output=history[histi][1])
915
- prompt, pre_response, terminate_response, chat_sep = generate_prompt(data_point, prompt_type1,
916
- chat1, reduced=True)
917
-                # markdown -> back to plain text; maybe not critical if the model is trained well enough
918
- if not kwargs['keep_sources_in_context']:
919
- from gpt_langchain import source_prefix, source_postfix
920
- import re
921
- prompt = re.sub(f'{re.escape(source_prefix)}.*?{re.escape(source_postfix)}', '', prompt,
922
- flags=re.DOTALL)
923
- if prompt.endswith('\n<p>'):
924
- prompt = prompt[:-4]
925
- prompt = prompt.replace('<br>', chat_sep)
926
- if not prompt.endswith(chat_sep):
927
- prompt += chat_sep
928
- # most recent first, add older if can
929
- # only include desired chat history
930
- if len(prompt + context1) > max_prompt_length:
931
- break
932
- context1 = prompt + context1
933
-
934
- _, pre_response, terminate_response, chat_sep = generate_prompt({}, prompt_type1, chat1,
935
- reduced=True)
936
- if context1 and not context1.endswith(chat_sep):
937
-                context1 += chat_sep  # ensure that if generation terminated abruptly, the human turn continues on a new line
938
- return context1
939
-
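Editor's note: per the inline comments, the loop above is meant to keep only the most recent turns that fit within a prompt-length budget. A simplified standalone sketch of that intent (plain strings instead of generate_prompt, budget in characters, hard-coded <human>/<bot> markers as assumptions), not a line-for-line port:

    def history_to_context_sketch(history, max_prompt_length=200, chat_sep='\n'):
        # history: list of [user_text, bot_text] pairs, oldest first; last pair is the current question
        context = ''
        for user_text, bot_text in reversed(history[:-1]):  # skip current turn, walk newest first
            turn = f"<human>: {user_text}{chat_sep}<bot>: {bot_text}{chat_sep}"
            if len(turn + context) > max_prompt_length:
                break  # older turns no longer fit the budget
            context = turn + context  # prepend so the final order is oldest -> newest
        return context

    hist = [["hi", "hello"], ["what is 2+2?", "4"], ["and 3+3?", None]]
    print(history_to_context_sketch(hist))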
940
- def bot(*args, retry=False):
941
- """
942
- bot that consumes history for user input
943
- instruction (from input_list) itself is not consumed by bot
944
- :param args:
945
- :param retry:
946
- :return:
947
- """
948
- # don't deepcopy, can contain model itself
949
- args_list = list(args).copy()
950
- model_state1 = args_list[-3]
951
- my_db_state1 = args_list[-2]
952
- history = args_list[-1]
953
-
954
- if model_state1[0] is None or model_state1[0] == no_model_str:
955
- history = []
956
- yield history, ''
957
- return
958
-
959
- args_list = args_list[:-3] # only keep rest needed for evaluate()
960
- langchain_mode1 = args_list[eval_func_param_names.index('langchain_mode')]
961
- if retry and history:
962
- history.pop()
963
- if not args_list[eval_func_param_names.index('do_sample')]:
964
- # if was not sampling, no point in retry unless change to sample
965
- args_list[eval_func_param_names.index('do_sample')] = True
966
- if not history:
967
- print("No history", flush=True)
968
- history = []
969
- yield history, ''
970
- return
971
- instruction1 = history[-1][0]
972
- if not instruction1:
973
- # reject empty query, can sometimes go nuts
974
- history = []
975
- yield history, ''
976
- return
977
- prompt_type1 = args_list[eval_func_param_names.index('prompt_type')]
978
- chat1 = args_list[eval_func_param_names.index('chat')]
979
- context1 = history_to_context(history, langchain_mode1, prompt_type1, chat1)
980
- args_list[0] = instruction1 # override original instruction with history from user
981
- args_list[2] = context1
982
- fun1 = partial(evaluate,
983
- model_state1,
984
- my_db_state1,
985
- **kwargs_evaluate)
986
- try:
987
- for output_fun in fun1(*tuple(args_list)):
988
- output = output_fun['response']
989
- extra = output_fun['sources'] # FIXME: can show sources in separate text box etc.
990
- # ensure good visually, else markdown ignores multiple \n
991
- bot_message = output.replace('\n', '<br>')
992
- history[-1][1] = bot_message
993
- yield history, ''
994
- except StopIteration:
995
- yield history, ''
996
- except RuntimeError as e:
997
- if "generator raised StopIteration" in str(e):
998
- # assume last entry was bad, undo
999
- history.pop()
1000
- yield history, ''
1001
- else:
1002
- if history and len(history) > 0 and len(history[0]) > 1 and history[-1][1] is None:
1003
- history[-1][1] = ''
1004
- yield history, str(e)
1005
- raise
1006
- except Exception as e:
1007
- # put error into user input
1008
- ex = "Exception: %s" % str(e)
1009
- if history and len(history) > 0 and len(history[0]) > 1 and history[-1][1] is None:
1010
- history[-1][1] = ''
1011
- yield history, ex
1012
- raise
1013
- return
1014
-
1015
- # NORMAL MODEL
1016
- user_args = dict(fn=functools.partial(user, sanitize_user_prompt=kwargs['sanitize_user_prompt']),
1017
- inputs=inputs_list + [text_output],
1018
- outputs=text_output,
1019
- )
1020
- bot_args = dict(fn=bot,
1021
- inputs=inputs_list + [model_state, my_db_state] + [text_output],
1022
- outputs=[text_output, exception_text],
1023
- )
1024
- retry_bot_args = dict(fn=functools.partial(bot, retry=True),
1025
- inputs=inputs_list + [model_state, my_db_state] + [text_output],
1026
- outputs=[text_output, exception_text],
1027
- )
1028
- undo_user_args = dict(fn=functools.partial(user, undo=True),
1029
- inputs=inputs_list + [text_output],
1030
- outputs=text_output,
1031
- )
1032
-
1033
- # MODEL2
1034
- user_args2 = dict(fn=functools.partial(user, sanitize_user_prompt=kwargs['sanitize_user_prompt'], model2=True),
1035
- inputs=inputs_list + [text_output2],
1036
- outputs=text_output2,
1037
- )
1038
- bot_args2 = dict(fn=bot,
1039
- inputs=inputs_list + [model_state2, my_db_state] + [text_output2],
1040
- outputs=[text_output2, exception_text],
1041
- )
1042
- retry_bot_args2 = dict(fn=functools.partial(bot, retry=True),
1043
- inputs=inputs_list + [model_state2, my_db_state] + [text_output2],
1044
- outputs=[text_output2, exception_text],
1045
- )
1046
- undo_user_args2 = dict(fn=functools.partial(user, undo=True),
1047
- inputs=inputs_list + [text_output2],
1048
- outputs=text_output2,
1049
- )
1050
-
1051
- def clear_instruct():
1052
- return gr.Textbox.update(value='')
1053
-
1054
- if kwargs['auto_score']:
1055
- score_args_submit = score_args
1056
- score_args2_submit = score_args2
1057
- else:
1058
- score_args_submit = dict(fn=lambda: None, inputs=None, outputs=None)
1059
- score_args2_submit = dict(fn=lambda: None, inputs=None, outputs=None)
1060
-
1061
-        # in case of a 2nd model, consume the instruction first so it can be cleared quickly
1062
-        # bot doesn't consume the instruction itself, just the history filled by user(), which is why this works
1063
- submit_event1a = instruction.submit(**user_args, queue=queue,
1064
- api_name='instruction' if allow_api else None)
1065
- submit_event1b = submit_event1a.then(**user_args2, api_name='instruction2' if allow_api else None)
1066
- submit_event1c = submit_event1b.then(clear_instruct, None, instruction) \
1067
- .then(clear_instruct, None, iinput)
1068
- submit_event1d = submit_event1c.then(**bot_args, api_name='instruction_bot' if allow_api else None,
1069
- queue=queue)
1070
- submit_event1e = submit_event1d.then(**score_args_submit,
1071
- api_name='instruction_bot_score' if allow_api else None,
1072
- queue=queue)
1073
- submit_event1f = submit_event1e.then(**bot_args2, api_name='instruction_bot2' if allow_api else None,
1074
- queue=queue)
1075
- submit_event1g = submit_event1f.then(**score_args2_submit,
1076
- api_name='instruction_bot_score2' if allow_api else None, queue=queue)
1077
- submit_event1h = submit_event1g.then(clear_torch_cache)
1078
-
1079
- submit_event2a = submit.click(**user_args, api_name='submit' if allow_api else None)
1080
- submit_event2b = submit_event2a.then(**user_args2, api_name='submit2' if allow_api else None)
1081
- submit_event2c = submit_event2b.then(clear_instruct, None, instruction) \
1082
- .then(clear_instruct, None, iinput)
1083
- submit_event2d = submit_event2c.then(**bot_args, api_name='submit_bot' if allow_api else None, queue=queue)
1084
- submit_event2e = submit_event2d.then(**score_args_submit, api_name='submit_bot_score' if allow_api else None,
1085
- queue=queue)
1086
- submit_event2f = submit_event2e.then(**bot_args2, api_name='submit_bot2' if allow_api else None, queue=queue)
1087
- submit_event2g = submit_event2f.then(**score_args2_submit, api_name='submit_bot_score2' if allow_api else None,
1088
- queue=queue)
1089
- submit_event2h = submit_event2g.then(clear_torch_cache)
1090
-
1091
- submit_event3a = retry.click(**user_args, api_name='retry' if allow_api else None)
1092
- submit_event3b = submit_event3a.then(**user_args2, api_name='retry2' if allow_api else None)
1093
- submit_event3c = submit_event3b.then(clear_instruct, None, instruction) \
1094
- .then(clear_instruct, None, iinput)
1095
- submit_event3d = submit_event3c.then(**retry_bot_args, api_name='retry_bot' if allow_api else None,
1096
- queue=queue)
1097
- submit_event3e = submit_event3d.then(**score_args_submit, api_name='retry_bot_score' if allow_api else None,
1098
- queue=queue)
1099
- submit_event3f = submit_event3e.then(**retry_bot_args2, api_name='retry_bot2' if allow_api else None,
1100
- queue=queue)
1101
- submit_event3g = submit_event3f.then(**score_args2_submit, api_name='retry_bot_score2' if allow_api else None,
1102
- queue=queue)
1103
- submit_event3h = submit_event3g.then(clear_torch_cache)
1104
-
1105
- submit_event4 = undo.click(**undo_user_args, api_name='undo' if allow_api else None) \
1106
- .then(**undo_user_args2, api_name='undo2' if allow_api else None) \
1107
- .then(clear_instruct, None, instruction) \
1108
- .then(clear_instruct, None, iinput) \
1109
- .then(**score_args_submit, api_name='undo_score' if allow_api else None) \
1110
- .then(**score_args2_submit, api_name='undo_score2' if allow_api else None)
1111
-
1112
- # MANAGE CHATS
1113
- def dedup(short_chat, short_chats):
1114
- if short_chat not in short_chats:
1115
- return short_chat
1116
- for i in range(1, 1000):
1117
- short_chat_try = short_chat + "_" + str(i)
1118
- if short_chat_try not in short_chats:
1119
- return short_chat_try
1120
- # fallback and hope for best
1121
- short_chat = short_chat + "_" + str(random.random())
1122
- return short_chat
1123
-
1124
- def get_short_chat(x, short_chats, short_len=20, words=4):
1125
- if x and len(x[0]) == 2 and x[0][0] is not None:
1126
- short_chat = ' '.join(x[0][0][:short_len].split(' ')[:words]).strip()
1127
- short_chat = dedup(short_chat, short_chats)
1128
- else:
1129
- short_chat = None
1130
- return short_chat
1131
-
1132
- def is_chat_same(x, y):
1133
-            # <p> etc. get added in chat; strip some of that to help avoid duplicate entries when starting a new conversation
1134
- is_same = True
1135
- # length of conversation has to be same
1136
- if len(x) != len(y):
1137
- return False
1138
- for stepx, stepy in zip(x, y):
1139
- if len(stepx) != len(stepy):
1140
- # something off with a conversation
1141
- return False
1142
- if len(stepx) != 2:
1143
- # something off
1144
- return False
1145
- if len(stepy) != 2:
1146
- # something off
1147
- return False
1148
- questionx = stepx[0].replace('<p>', '').replace('</p>', '') if stepx[0] is not None else None
1149
- answerx = stepx[1].replace('<p>', '').replace('</p>', '') if stepx[1] is not None else None
1150
-
1151
- questiony = stepy[0].replace('<p>', '').replace('</p>', '') if stepy[0] is not None else None
1152
- answery = stepy[1].replace('<p>', '').replace('</p>', '') if stepy[1] is not None else None
1153
-
1154
- if questionx != questiony or answerx != answery:
1155
- return False
1156
- return is_same
1157
-
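Editor's note: is_chat_same() only compares conversations after stripping the <p> wrappers the chat display adds. A small hedged sketch of that normalize-then-compare idea (standalone helpers for illustration, without the per-step length checks of the original):

    def strip_markup(text):
        # the chat display wraps turns in <p>...</p>; remove just that before comparing
        return text.replace('<p>', '').replace('</p>', '') if text is not None else None

    def same_conversation(x, y):
        if len(x) != len(y):
            return False
        return all(strip_markup(a) == strip_markup(c) and strip_markup(b) == strip_markup(d)
                   for (a, b), (c, d) in zip(x, y))

    print(same_conversation([["<p>hi</p>", "hello"]], [["hi", "<p>hello</p>"]]))  # True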
1158
- def save_chat(chat1, chat2, chat_state1):
1159
- short_chats = list(chat_state1.keys())
1160
- for chati in [chat1, chat2]:
1161
- if chati and len(chati) > 0 and len(chati[0]) == 2 and chati[0][1] is not None:
1162
- short_chat = get_short_chat(chati, short_chats)
1163
- if short_chat:
1164
- already_exists = any([is_chat_same(chati, x) for x in chat_state1.values()])
1165
- if not already_exists:
1166
- chat_state1[short_chat] = chati
1167
- return chat_state1
1168
-
1169
- def update_radio_chats(chat_state1):
1170
- return gr.update(choices=list(chat_state1.keys()), value=None)
1171
-
1172
- def deselect_radio_chats():
1173
- return gr.update(value=None)
1174
-
1175
- def switch_chat(chat_key, chat_state1):
1176
- chosen_chat = chat_state1[chat_key]
1177
- return chosen_chat, chosen_chat
1178
-
1179
- radio_chats.input(switch_chat, inputs=[radio_chats, chat_state], outputs=[text_output, text_output2])
1180
-
1181
- def remove_chat(chat_key, chat_state1):
1182
- chat_state1.pop(chat_key, None)
1183
- return chat_state1
1184
-
1185
- remove_chat_btn.click(remove_chat, inputs=[radio_chats, chat_state], outputs=chat_state) \
1186
- .then(update_radio_chats, inputs=chat_state, outputs=radio_chats)
1187
-
1188
- def get_chats1(chat_state1):
1189
- base = 'chats'
1190
- makedirs(base, exist_ok=True)
1191
- filename = os.path.join(base, 'chats_%s.json' % str(uuid.uuid4()))
1192
- with open(filename, "wt") as f:
1193
- f.write(json.dumps(chat_state1, indent=2))
1194
- return filename
1195
-
1196
- export_chats_btn.click(get_chats1, inputs=chat_state, outputs=chats_file, queue=False,
1197
- api_name='export_chats' if allow_api else None)
1198
-
1199
- def add_chats_from_file(file, chat_state1, add_btn):
1200
- if not file:
1201
- return chat_state1, add_btn
1202
- if isinstance(file, str):
1203
- files = [file]
1204
- else:
1205
- files = file
1206
- if not files:
1207
- return chat_state1, add_btn
1208
- for file1 in files:
1209
- try:
1210
- if hasattr(file1, 'name'):
1211
- file1 = file1.name
1212
- with open(file1, "rt") as f:
1213
- new_chats = json.loads(f.read())
1214
- for chat1_k, chat1_v in new_chats.items():
1215
- # ignore chat1_k, regenerate and de-dup to avoid loss
1216
- chat_state1 = save_chat(chat1_v, None, chat_state1)
1217
- except BaseException as e:
1218
- print("Add chats exception: %s" % str(e), flush=True)
1219
- return chat_state1, add_btn
1220
-
1221
- # note for update_user_db_func output is ignored for db
1222
- add_to_chats_btn.click(add_chats_from_file,
1223
- inputs=[chatsup_output, chat_state, add_to_chats_btn],
1224
- outputs=[chat_state, add_to_my_db_btn], queue=False,
1225
- api_name='add_to_chats' if allow_api else None) \
1226
- .then(clear_file_list, outputs=chatsup_output, queue=False) \
1227
- .then(update_radio_chats, inputs=chat_state, outputs=radio_chats, queue=False)
1228
-
1229
- clear_chat_btn.click(lambda: None, None, text_output, queue=False, api_name='clear' if allow_api else None) \
1230
- .then(lambda: None, None, text_output2, queue=False, api_name='clear2' if allow_api else None) \
1231
- .then(deselect_radio_chats, inputs=None, outputs=radio_chats, queue=False)
1232
-
1233
- # does both models
1234
- clear.click(save_chat, inputs=[text_output, text_output2, chat_state], outputs=chat_state,
1235
- api_name='save_chat' if allow_api else None) \
1236
- .then(update_radio_chats, inputs=chat_state, outputs=radio_chats,
1237
- api_name='update_chats' if allow_api else None) \
1238
- .then(lambda: None, None, text_output, queue=False, api_name='clearB' if allow_api else None) \
1239
- .then(lambda: None, None, text_output2, queue=False, api_name='clearB2' if allow_api else None)
1240
- # NOTE: clear of instruction/iinput for nochat has to come after score,
1241
- # because score for nochat consumes actual textbox, while chat consumes chat history filled by user()
1242
- submit_event_nochat = submit_nochat.click(fun,
1243
- inputs=[model_state, my_db_state] + inputs_list,
1244
- outputs=text_output_nochat,
1245
- queue=queue,
1246
- api_name='submit_nochat' if allow_api else None) \
1247
- .then(**score_args_nochat, api_name='instruction_bot_score_nochat' if allow_api else None, queue=queue) \
1248
- .then(clear_instruct, None, instruction_nochat) \
1249
- .then(clear_instruct, None, iinput_nochat) \
1250
- .then(clear_torch_cache)
1251
-
1252
- def load_model(model_name, lora_weights, model_state_old, prompt_type_old, load_8bit, infer_devices, gpu_id):
1253
- # ensure old model removed from GPU memory
1254
- if kwargs['debug']:
1255
- print("Pre-switch pre-del GPU memory: %s" % get_torch_allocated(), flush=True)
1256
-
1257
- model0 = model_state0[0]
1258
- if isinstance(model_state_old[0], str) and model0 is not None:
1259
- # best can do, move model loaded at first to CPU
1260
- model0.cpu()
1261
-
1262
- if model_state_old[0] is not None and not isinstance(model_state_old[0], str):
1263
- try:
1264
- model_state_old[0].cpu()
1265
- except Exception as e:
1266
- # sometimes hit NotImplementedError: Cannot copy out of meta tensor; no data!
1267
- print("Unable to put model on CPU: %s" % str(e), flush=True)
1268
- del model_state_old[0]
1269
- model_state_old[0] = None
1270
-
1271
- if model_state_old[1] is not None and not isinstance(model_state_old[1], str):
1272
- del model_state_old[1]
1273
- model_state_old[1] = None
1274
-
1275
- clear_torch_cache()
1276
- if kwargs['debug']:
1277
- print("Pre-switch post-del GPU memory: %s" % get_torch_allocated(), flush=True)
1278
-
1279
- if model_name is None or model_name == no_model_str:
1280
- # no-op if no model, just free memory
1281
- # no detranscribe needed for model, never go into evaluate
1282
- lora_weights = no_lora_str
1283
- return [None, None, None, model_name], model_name, lora_weights, prompt_type_old
1284
-
1285
- # don't deepcopy, can contain model itself
1286
- all_kwargs1 = all_kwargs.copy()
1287
- all_kwargs1['base_model'] = model_name.strip()
1288
- all_kwargs1['load_8bit'] = load_8bit
1289
- all_kwargs1['infer_devices'] = infer_devices
1290
- all_kwargs1['gpu_id'] = int(gpu_id) # detranscribe
1291
- model_lower = model_name.strip().lower()
1292
- if model_lower in inv_prompt_type_to_model_lower:
1293
- prompt_type1 = inv_prompt_type_to_model_lower[model_lower]
1294
- else:
1295
- prompt_type1 = prompt_type_old
1296
-
1297
- # detranscribe
1298
- if lora_weights == no_lora_str:
1299
- lora_weights = ''
1300
-
1301
- all_kwargs1['lora_weights'] = lora_weights.strip()
1302
- model1, tokenizer1, device1 = get_model(reward_type=False,
1303
- **get_kwargs(get_model, exclude_names=['reward_type'],
1304
- **all_kwargs1))
1305
- clear_torch_cache()
1306
-
1307
- if kwargs['debug']:
1308
- print("Post-switch GPU memory: %s" % get_torch_allocated(), flush=True)
1309
- return [model1, tokenizer1, device1, model_name], model_name, lora_weights, prompt_type1
1310
-
1311
- def dropdown_prompt_type_list(x):
1312
- return gr.Dropdown.update(value=x)
1313
-
1314
- def chatbot_list(x, model_used_in):
1315
- return gr.Textbox.update(label=f'h2oGPT [Model: {model_used_in}]')
1316
-
1317
- load_model_args = dict(fn=load_model,
1318
- inputs=[model_choice, lora_choice, model_state, prompt_type,
1319
- model_load8bit_checkbox, model_infer_devices_checkbox, model_gpu],
1320
- outputs=[model_state, model_used, lora_used, prompt_type])
1321
- prompt_update_args = dict(fn=dropdown_prompt_type_list, inputs=prompt_type, outputs=prompt_type)
1322
- chatbot_update_args = dict(fn=chatbot_list, inputs=[text_output, model_used], outputs=text_output)
1323
- nochat_update_args = dict(fn=chatbot_list, inputs=[text_output_nochat, model_used], outputs=text_output_nochat)
1324
- if not is_public:
1325
- load_model_event = load_model_button.click(**load_model_args, api_name='load_model' if allow_api else None) \
1326
- .then(**prompt_update_args) \
1327
- .then(**chatbot_update_args) \
1328
- .then(**nochat_update_args) \
1329
- .then(clear_torch_cache)
1330
-
1331
- load_model_args2 = dict(fn=load_model,
1332
- inputs=[model_choice2, lora_choice2, model_state2, prompt_type2,
1333
- model_load8bit_checkbox2, model_infer_devices_checkbox2, model_gpu2],
1334
- outputs=[model_state2, model_used2, lora_used2, prompt_type2])
1335
- prompt_update_args2 = dict(fn=dropdown_prompt_type_list, inputs=prompt_type2, outputs=prompt_type2)
1336
- chatbot_update_args2 = dict(fn=chatbot_list, inputs=[text_output2, model_used2], outputs=text_output2)
1337
- if not is_public:
1338
- load_model_event2 = load_model_button2.click(**load_model_args2,
1339
- api_name='load_model2' if allow_api else None) \
1340
- .then(**prompt_update_args2) \
1341
- .then(**chatbot_update_args2) \
1342
- .then(clear_torch_cache)
1343
-
1344
- def dropdown_model_list(list0, x):
1345
- new_state = [list0[0] + [x]]
1346
- new_options = [*new_state[0]]
1347
- return gr.Dropdown.update(value=x, choices=new_options), \
1348
- gr.Dropdown.update(value=x, choices=new_options), \
1349
- '', new_state
1350
-
1351
- add_model_event = add_model_button.click(fn=dropdown_model_list,
1352
- inputs=[model_options_state, new_model],
1353
- outputs=[model_choice, model_choice2, new_model, model_options_state],
1354
- queue=False)
1355
-
1356
- def dropdown_lora_list(list0, x, model_used1, lora_used1, model_used2, lora_used2):
1357
- new_state = [list0[0] + [x]]
1358
- new_options = [*new_state[0]]
1359
- # don't switch drop-down to added lora if already have model loaded
1360
- x1 = x if model_used1 == no_model_str else lora_used1
1361
- x2 = x if model_used2 == no_model_str else lora_used2
1362
- return gr.Dropdown.update(value=x1, choices=new_options), \
1363
- gr.Dropdown.update(value=x2, choices=new_options), \
1364
- '', new_state
1365
-
1366
- add_lora_event = add_lora_button.click(fn=dropdown_lora_list,
1367
- inputs=[lora_options_state, new_lora, model_used, lora_used, model_used2,
1368
- lora_used2],
1369
- outputs=[lora_choice, lora_choice2, new_lora, lora_options_state],
1370
- queue=False)
1371
-
1372
- go_btn.click(lambda: gr.update(visible=False), None, go_btn, api_name="go" if allow_api else None, queue=False) \
1373
- .then(lambda: gr.update(visible=True), None, normal_block, queue=False) \
1374
- .then(**load_model_args, queue=False).then(**prompt_update_args, queue=False)
1375
-
1376
- def compare_textbox_fun(x):
1377
- return gr.Textbox.update(visible=x)
1378
-
1379
- def compare_column_fun(x):
1380
- return gr.Column.update(visible=x)
1381
-
1382
- def compare_prompt_fun(x):
1383
- return gr.Dropdown.update(visible=x)
1384
-
1385
- compare_checkbox.select(compare_textbox_fun, compare_checkbox, text_output2,
1386
- api_name="compare_checkbox" if allow_api else None) \
1387
- .then(compare_column_fun, compare_checkbox, col_model2) \
1388
- .then(compare_prompt_fun, compare_checkbox, prompt_type2) \
1389
- .then(compare_textbox_fun, compare_checkbox, score_text2)
1390
- # FIXME: add score_res2 in condition, but do better
1391
-
1392
- # callback for logging flagged input/output
1393
- callback.setup(inputs_list + [text_output, text_output2], "flagged_data_points")
1394
- flag_btn.click(lambda *args: callback.flag(args), inputs_list + [text_output, text_output2], None,
1395
- preprocess=False,
1396
- api_name='flag' if allow_api else None, queue=False)
1397
- flag_btn_nochat.click(lambda *args: callback.flag(args), inputs_list + [text_output_nochat], None,
1398
- preprocess=False,
1399
- api_name='flag_nochat' if allow_api else None, queue=False)
1400
-
1401
- def get_system_info():
1402
- return gr.Textbox.update(value=system_info_print())
1403
-
1404
- system_event = system_btn.click(get_system_info, outputs=system_text,
1405
- api_name='system_info' if allow_api else None, queue=False)
1406
-
1407
- # don't pass text_output, don't want to clear output, just stop it
1408
- # cancel only stops outer generation, not inner generation or non-generation
1409
- stop_btn.click(lambda: None, None, None,
1410
- cancels=[submit_event1d, submit_event1f,
1411
- submit_event2d, submit_event2f,
1412
- submit_event3d, submit_event3f,
1413
- submit_event_nochat],
1414
- queue=False, api_name='stop' if allow_api else None).then(clear_torch_cache, queue=False)
1415
-
1416
- def count_chat_tokens(model_state1, chat1, prompt_type1):
1417
- if model_state1 and not isinstance(model_state1[1], str):
1418
- tokenizer = model_state1[1]
1419
- elif model_state0 and not isinstance(model_state0[1], str):
1420
- tokenizer = model_state0[1]
1421
- else:
1422
- tokenizer = None
1423
- if tokenizer is not None:
1424
- langchain_mode1 = 'ChatLLM'
1425
- # fake user message to mimic bot()
1426
- chat1 = copy.deepcopy(chat1)
1427
- chat1 = chat1 + [['user_message1', None]]
1428
- context1 = history_to_context(chat1, langchain_mode1, prompt_type1, chat1)
1429
- return str(tokenizer(context1, return_tensors="pt")['input_ids'].shape[1])
1430
- else:
1431
- return "N/A"
1432
-
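Editor's note: the token count above is just the second dimension of the tokenizer's input_ids. A self-contained sketch of the same measurement with a Hugging Face tokenizer; "gpt2" is only an example model (downloading it needs network access, and return_tensors="pt" needs torch installed):

    from transformers import AutoTokenizer

    def count_tokens(text, tokenizer_name="gpt2"):  # example model; substitute the loaded model's tokenizer
        tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
        return tokenizer(text, return_tensors="pt")["input_ids"].shape[1]

    # print(count_tokens("<human>: hello\n<bot>: hi there\n"))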
1433
- count_chat_tokens_btn.click(fn=count_chat_tokens, inputs=[model_state, text_output, prompt_type],
1434
- outputs=chat_token_count, api_name='count_tokens' if allow_api else None)
1435
-
1436
- demo.load(None, None, None, _js=get_dark_js() if kwargs['h2ocolors'] else None)
1437
-
1438
- demo.queue(concurrency_count=kwargs['concurrency_count'], api_open=kwargs['api_open'])
1439
- favicon_path = "h2o-logo.svg"
1440
-
1441
- scheduler = BackgroundScheduler()
1442
- scheduler.add_job(func=clear_torch_cache, trigger="interval", seconds=20)
1443
- if is_public and \
1444
- kwargs['base_model'] not in non_hf_types:
1445
-    # FIXME: disabled for gptj because langchain or gpt4all modify print itself
1446
- # FIXME: and any multi-threaded/async print will enter model output!
1447
- scheduler.add_job(func=ping, trigger="interval", seconds=60)
1448
- scheduler.start()
1449
-
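Editor's note: the periodic cache clearing and ping above use APScheduler's BackgroundScheduler. A minimal self-contained sketch of that pattern, with a stand-in job function instead of clear_torch_cache()/ping():

    from apscheduler.schedulers.background import BackgroundScheduler

    def heartbeat():
        print("tick", flush=True)  # stand-in for clear_torch_cache() / ping()

    scheduler = BackgroundScheduler()
    scheduler.add_job(func=heartbeat, trigger="interval", seconds=20)
    scheduler.start()
    # ... run the app; on shutdown:
    # scheduler.shutdown()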
1450
- # import control
1451
- if kwargs['langchain_mode'] == 'Disabled' and \
1452
- os.environ.get("TEST_LANGCHAIN_IMPORT") and \
1453
- kwargs['base_model'] not in non_hf_types:
1454
- assert 'gpt_langchain' not in sys.modules, "Dev bug, import of langchain when should not have"
1455
- assert 'langchain' not in sys.modules, "Dev bug, import of langchain when should not have"
1456
-
1457
- demo.launch(share=kwargs['share'], server_name="0.0.0.0", show_error=True,
1458
- favicon_path=favicon_path, prevent_thread_lock=True,
1459
- auth=kwargs['auth'])
1460
- if kwargs['verbose']:
1461
- print("Started GUI", flush=True)
1462
- if kwargs['block_gradio_exit']:
1463
- demo.block_thread()
1464
-
1465
-
1466
- input_args_list = ['model_state', 'my_db_state']
1467
-
1468
-
1469
- def get_inputs_list(inputs_dict, model_lower):
1470
- """
1471
- map gradio objects in locals() to inputs for evaluate().
1472
- :param inputs_dict:
1473
- :param model_lower:
1474
- :return:
1475
- """
1476
- inputs_list_names = list(inspect.signature(evaluate).parameters)
1477
- inputs_list = []
1478
- for k in inputs_list_names:
1479
- if k == 'kwargs':
1480
- continue
1481
- if k in input_args_list + inputs_kwargs_list:
1482
- # these are added at use time for args or partial for kwargs, not taken as input
1483
- continue
1484
- if 'mbart-' not in model_lower and k in ['src_lang', 'tgt_lang']:
1485
- continue
1486
- inputs_list.append(inputs_dict[k])
1487
- return inputs_list
1488
-
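Editor's note: get_inputs_list() relies on inspect.signature to keep the Gradio components aligned with evaluate()'s parameter order. The core trick in isolation (evaluate_stub and the component names are made up for illustration):

    import inspect

    def evaluate_stub(instruction, iinput, context, temperature=0.7, **kwargs):
        pass

    param_names = list(inspect.signature(evaluate_stub).parameters)
    # ['instruction', 'iinput', 'context', 'temperature', 'kwargs']
    components = {'instruction': 'TextboxA', 'iinput': 'TextboxB', 'context': 'TextboxC'}
    inputs_list = [components[k] for k in param_names if k in components]
    print(inputs_list)  # preserves evaluate_stub's declaration order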
1489
-
1490
- def get_sources(db1, langchain_mode, dbs=None, docs_state0=None):
1491
- if langchain_mode in ['ChatLLM', 'LLM']:
1492
- source_files_added = "NA"
1493
- source_list = []
1494
- elif langchain_mode in ['wiki_full']:
1495
- source_files_added = "Not showing wiki_full, takes about 20 seconds and makes 4MB file." \
1496
- " Ask jon.mckinney@h2o.ai for file if required."
1497
- source_list = []
1498
- elif langchain_mode == 'MyData' and len(db1) > 0 and db1[0] is not None:
1499
- db_get = db1[0].get()
1500
- source_list = sorted(set([x['source'] for x in db_get['metadatas']]))
1501
- source_files_added = '\n'.join(source_list)
1502
- elif langchain_mode in dbs and dbs[langchain_mode] is not None:
1503
- db1 = dbs[langchain_mode]
1504
- db_get = db1.get()
1505
- source_list = sorted(set([x['source'] for x in db_get['metadatas']]))
1506
- source_files_added = '\n'.join(source_list)
1507
- else:
1508
- source_list = []
1509
- source_files_added = "None"
1510
- sources_file = 'sources_%s_%s' % (langchain_mode, str(uuid.uuid4()))
1511
- with open(sources_file, "wt") as f:
1512
- f.write(source_files_added)
1513
- source_list = docs_state0 + source_list
1514
- return sources_file, source_list
1515
-
1516
-
1517
- def update_user_db(file, db1, x, y, *args, dbs=None, langchain_mode='UserData', **kwargs):
1518
- try:
1519
- return _update_user_db(file, db1, x, y, *args, dbs=dbs, langchain_mode=langchain_mode, **kwargs)
1520
- except BaseException as e:
1521
- print(traceback.format_exc(), flush=True)
1522
- # gradio has issues if an exception escapes, so fail semi-gracefully, else it would hang forever in the processing textbox
1523
- ex_str = "Exception: %s" % str(e)
1524
- source_files_added = """\
1525
- <html>
1526
- <body>
1527
- <p>
1528
- Sources: <br>
1529
- </p>
1530
- <div style="overflow-y: auto;height:400px">
1531
- {0}
1532
- </div>
1533
- </body>
1534
- </html>
1535
- """.format(ex_str)
1536
- if langchain_mode == 'MyData':
1537
- return db1, x, y, source_files_added
1538
- else:
1539
- return x, y, source_files_added
1540
-
1541
-
1542
- def _update_user_db(file, db1, x, y, dbs=None, db_type=None, langchain_mode='UserData', use_openai_embedding=False,
1543
- hf_embedding_model="sentence-transformers/all-MiniLM-L6-v2",
1544
- caption_loader=None,
1545
- enable_captions=True,
1546
- captions_model="Salesforce/blip-image-captioning-base",
1547
- enable_ocr=False,
1548
- verbose=False,
1549
- chunk=True, chunk_size=512, is_url=False, is_txt=False):
1550
- assert isinstance(dbs, dict), "Wrong type for dbs: %s" % str(type(dbs))
1551
- assert db_type in ['faiss', 'chroma'], "db_type %s not supported" % db_type
1552
- from gpt_langchain import add_to_db, get_db, path_to_docs
1553
- # handle case of list of temp buffer
1554
- if isinstance(file, list) and len(file) > 0 and hasattr(file[0], 'name'):
1555
- file = [x.name for x in file]
1556
- # handle single file of temp buffer
1557
- if hasattr(file, 'name'):
1558
- file = file.name
1559
- if verbose:
1560
- print("Adding %s" % file, flush=True)
1561
- sources = path_to_docs(file if not is_url and not is_txt else None,
1562
- verbose=verbose, chunk=chunk, chunk_size=chunk_size,
1563
- url=file if is_url else None,
1564
- text=file if is_txt else None,
1565
- enable_captions=enable_captions,
1566
- captions_model=captions_model,
1567
- enable_ocr=enable_ocr,
1568
- caption_loader=caption_loader,
1569
- )
1570
- exceptions = [x for x in sources if x.metadata.get('exception')]
1571
- sources = [x for x in sources if 'exception' not in x.metadata]
1572
-
1573
- with filelock.FileLock("db_%s.lock" % langchain_mode.replace(' ', '_')):
1574
- if langchain_mode == 'MyData':
1575
- if db1[0] is not None:
1576
- # then add
1577
- db, num_new_sources, new_sources_metadata = add_to_db(db1[0], sources, db_type=db_type)
1578
- else:
1579
- assert len(db1) == 2 and db1[1] is None, "Bad MyData db: %s" % db1
1580
- # then create
1581
- # assign fresh hash for this user session, so not shared
1582
- # if the hash were set in the original state and never changed, the db would be shared across all users
1583
- db1[1] = str(uuid.uuid4())
1584
- persist_directory = os.path.join(scratch_base_dir, 'db_dir_%s_%s' % (langchain_mode, db1[1]))
1585
- db1[0] = get_db(sources, use_openai_embedding=use_openai_embedding,
1586
- db_type=db_type,
1587
- persist_directory=persist_directory,
1588
- langchain_mode=langchain_mode,
1589
- hf_embedding_model=hf_embedding_model)
1590
- if db1[0] is None:
1591
- db1[1] = None
1592
- source_files_added = get_source_files(db=db1[0], exceptions=exceptions)
1593
- return db1, x, y, source_files_added
1594
- else:
1595
- persist_directory = 'db_dir_%s' % langchain_mode
1596
- if langchain_mode in dbs and dbs[langchain_mode] is not None:
1597
- # then add
1598
- db, num_new_sources, new_sources_metadata = add_to_db(dbs[langchain_mode], sources, db_type=db_type)
1599
- else:
1600
- # then create
1601
- db = get_db(sources, use_openai_embedding=use_openai_embedding,
1602
- db_type=db_type,
1603
- persist_directory=persist_directory,
1604
- langchain_mode=langchain_mode,
1605
- hf_embedding_model=hf_embedding_model)
1606
- dbs[langchain_mode] = db
1607
- # NOTE we do not return db, because function call always same code path
1608
- # return dbs[langchain_mode], x, y
1609
- # db in this code path is updated in place
1610
- source_files_added = get_source_files(db=dbs[langchain_mode], exceptions=exceptions)
1611
- return x, y, source_files_added
1612
-
1613
-
1614
- def get_db(db1, langchain_mode, dbs=None):
1615
- with filelock.FileLock("db_%s.lock" % langchain_mode.replace(' ', '_')):
1616
- if langchain_mode in ['wiki_full']:
1617
- # NOTE: avoid showing full wiki. Takes about 30 seconds over about 90k entries, but not useful for now
1618
- db = None
1619
- elif langchain_mode == 'MyData' and len(db1) > 0 and db1[0] is not None:
1620
- db = db1[0]
1621
- elif langchain_mode in dbs and dbs[langchain_mode] is not None:
1622
- db = dbs[langchain_mode]
1623
- else:
1624
- db = None
1625
- return db
1626
-
1627
-
1628
- def get_source_files_given_langchain_mode(db1, langchain_mode='UserData', dbs=None):
1629
- db = get_db(db1, langchain_mode, dbs=dbs)
1630
- return get_source_files(db=db, exceptions=None)
1631
-
1632
-
1633
- def get_source_files(db=None, exceptions=None, metadatas=None):
1634
- if exceptions is None:
1635
- exceptions = []
1636
-
1637
- # only should be one source, not confused
1638
- assert db is not None or metadatas is not None
1639
-
1640
- if metadatas is None:
1641
- source_label = "Sources:"
1642
- if db is not None:
1643
- metadatas = db.get()['metadatas']
1644
- else:
1645
- metadatas = []
1646
- adding_new = False
1647
- else:
1648
- source_label = "New Sources:"
1649
- adding_new = True
1650
-
1651
- # below automatically de-dups
1652
- from gpt_langchain import get_url
1653
- small_dict = {get_url(x['source'], from_str=True, short_name=True): get_short_name(x.get('head')) for x in
1654
- metadatas}
1655
- # if small_dict is empty dict, that's ok
1656
- df = pd.DataFrame(small_dict.items(), columns=['source', 'head'])
1657
- df.index = df.index + 1
1658
- df.index.name = 'index'
1659
- source_files_added = tabulate.tabulate(df, headers='keys', tablefmt='unsafehtml')
1660
-
1661
- if exceptions:
1662
- exception_metadatas = [x.metadata for x in exceptions]
1663
- small_dict = {get_url(x['source'], from_str=True, short_name=True): get_short_name(x.get('exception')) for x in
1664
- exception_metadatas}
1665
- # if small_dict is empty dict, that's ok
1666
- df = pd.DataFrame(small_dict.items(), columns=['source', 'exception'])
1667
- df.index = df.index + 1
1668
- df.index.name = 'index'
1669
- exceptions_html = tabulate.tabulate(df, headers='keys', tablefmt='unsafehtml')
1670
- else:
1671
- exceptions_html = ''
1672
-
1673
- if metadatas and exceptions:
1674
- source_files_added = """\
1675
- <html>
1676
- <body>
1677
- <p>
1678
- {0} <br>
1679
- </p>
1680
- <div style="overflow-y: auto;height:400px">
1681
- {1}
1682
- {2}
1683
- </div>
1684
- </body>
1685
- </html>
1686
- """.format(source_label, source_files_added, exceptions_html)
1687
- elif metadatas:
1688
- source_files_added = """\
1689
- <html>
1690
- <body>
1691
- <p>
1692
- {0} <br>
1693
- </p>
1694
- <div style="overflow-y: auto;height:400px">
1695
- {1}
1696
- </div>
1697
- </body>
1698
- </html>
1699
- """.format(source_label, source_files_added)
1700
- elif exceptions_html:
1701
- source_files_added = """\
1702
- <html>
1703
- <body>
1704
- <p>
1705
- Exceptions: <br>
1706
- </p>
1707
- <div style="overflow-y: auto;height:400px">
1708
- {0}
1709
- </div>
1710
- </body>
1711
- </html>
1712
- """.format(exceptions_html)
1713
- else:
1714
- if adding_new:
1715
- source_files_added = "No New Sources"
1716
- else:
1717
- source_files_added = "No Sources"
1718
-
1719
- return source_files_added
1720
-
1721
-
1722
- def update_and_get_source_files_given_langchain_mode(db1, langchain_mode, dbs=None, first_para=None,
1723
- text_limit=None, chunk=None, chunk_size=None,
1724
- user_path=None, db_type=None, load_db_if_exists=None,
1725
- n_jobs=None, verbose=None):
1726
- db = get_db(db1, langchain_mode, dbs=dbs)
1727
-
1728
- from gpt_langchain import make_db
1729
- db, num_new_sources, new_sources_metadata = make_db(use_openai_embedding=False,
1730
- hf_embedding_model="sentence-transformers/all-MiniLM-L6-v2",
1731
- first_para=first_para, text_limit=text_limit, chunk=chunk,
1732
- chunk_size=chunk_size,
1733
- langchain_mode=langchain_mode,
1734
- user_path=user_path,
1735
- db_type=db_type,
1736
- load_db_if_exists=load_db_if_exists,
1737
- db=db,
1738
- n_jobs=n_jobs,
1739
- verbose=verbose)
1740
- # return only new sources with text saying such
1741
- return get_source_files(db=None, exceptions=None, metadatas=new_sources_metadata)
gradio_runner.py ADDED
@@ -0,0 +1 @@
+ ../../gradio_runner.py
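For reference, the periodic maintenance wired up above (clearing the torch CUDA cache every 20 seconds and pinging every 60 seconds) follows the standard APScheduler pattern. A minimal standalone sketch, assuming utils.py (shown further down) is importable and APScheduler is installed as pinned in requirements.txt:

from apscheduler.schedulers.background import BackgroundScheduler

from utils import clear_torch_cache, ping  # both defined in utils.py below

scheduler = BackgroundScheduler()
# free cached CUDA memory periodically so long-lived sessions do not accumulate allocations
scheduler.add_job(func=clear_torch_cache, trigger="interval", seconds=20)
# heartbeat log; in the app this job is skipped for models that redirect print into the output stream
scheduler.add_job(func=ping, trigger="interval", seconds=60)
scheduler.start()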
gradio_themes.py DELETED
@@ -1,183 +0,0 @@
1
- from __future__ import annotations
2
-
3
- from typing import Iterable
4
-
5
- from gradio.themes.soft import Soft
6
- from gradio.themes import Color
7
- from gradio.themes.utils import colors, sizes, fonts
8
-
9
- h2o_yellow = Color(
10
- name="yellow",
11
- c50="#fffef2",
12
- c100="#fff9e6",
13
- c200="#ffecb3",
14
- c300="#ffe28c",
15
- c400="#ffd659",
16
- c500="#fec925",
17
- c600="#e6ac00",
18
- c700="#bf8f00",
19
- c800="#a67c00",
20
- c900="#664d00",
21
- c950="#403000",
22
- )
23
- h2o_gray = Color(
24
- name="gray",
25
- c50="#f8f8f8",
26
- c100="#e5e5e5",
27
- c200="#cccccc",
28
- c300="#b2b2b2",
29
- c400="#999999",
30
- c500="#7f7f7f",
31
- c600="#666666",
32
- c700="#4c4c4c",
33
- c800="#333333",
34
- c900="#191919",
35
- c950="#0d0d0d",
36
- )
37
-
38
-
39
- class H2oTheme(Soft):
40
- def __init__(
41
- self,
42
- *,
43
- primary_hue: colors.Color | str = h2o_yellow,
44
- secondary_hue: colors.Color | str = h2o_yellow,
45
- neutral_hue: colors.Color | str = h2o_gray,
46
- spacing_size: sizes.Size | str = sizes.spacing_md,
47
- radius_size: sizes.Size | str = sizes.radius_md,
48
- text_size: sizes.Size | str = sizes.text_lg,
49
- font: fonts.Font
50
- | str
51
- | Iterable[fonts.Font | str] = (
52
- fonts.GoogleFont("Montserrat"),
53
- "ui-sans-serif",
54
- "system-ui",
55
- "sans-serif",
56
- ),
57
- font_mono: fonts.Font
58
- | str
59
- | Iterable[fonts.Font | str] = (
60
- fonts.GoogleFont("IBM Plex Mono"),
61
- "ui-monospace",
62
- "Consolas",
63
- "monospace",
64
- ),
65
- ):
66
- super().__init__(
67
- primary_hue=primary_hue,
68
- secondary_hue=secondary_hue,
69
- neutral_hue=neutral_hue,
70
- spacing_size=spacing_size,
71
- radius_size=radius_size,
72
- text_size=text_size,
73
- font=font,
74
- font_mono=font_mono,
75
- )
76
- super().set(
77
- link_text_color="#3344DD",
78
- link_text_color_hover="#3344DD",
79
- link_text_color_visited="#3344DD",
80
- link_text_color_dark="#74abff",
81
- link_text_color_hover_dark="#a3c8ff",
82
- link_text_color_active_dark="#a3c8ff",
83
- link_text_color_visited_dark="#74abff",
84
- button_primary_text_color="*neutral_950",
85
- button_primary_text_color_dark="*neutral_950",
86
- button_primary_background_fill="*primary_500",
87
- button_primary_background_fill_dark="*primary_500",
88
- block_label_background_fill="*primary_500",
89
- block_label_background_fill_dark="*primary_500",
90
- block_label_text_color="*neutral_950",
91
- block_label_text_color_dark="*neutral_950",
92
- block_title_text_color="*neutral_950",
93
- block_title_text_color_dark="*neutral_950",
94
- block_background_fill_dark="*neutral_950",
95
- body_background_fill="*neutral_50",
96
- body_background_fill_dark="*neutral_900",
97
- background_fill_primary_dark="*block_background_fill",
98
- block_radius="0 0 8px 8px",
99
- checkbox_label_text_color_selected_dark='#000000',
100
- )
101
-
102
-
103
- class SoftTheme(Soft):
104
- def __init__(
105
- self,
106
- *,
107
- primary_hue: colors.Color | str = colors.indigo,
108
- secondary_hue: colors.Color | str = colors.indigo,
109
- neutral_hue: colors.Color | str = colors.gray,
110
- spacing_size: sizes.Size | str = sizes.spacing_md,
111
- radius_size: sizes.Size | str = sizes.radius_md,
112
- text_size: sizes.Size | str = sizes.text_md,
113
- font: fonts.Font
114
- | str
115
- | Iterable[fonts.Font | str] = (
116
- fonts.GoogleFont("Montserrat"),
117
- "ui-sans-serif",
118
- "system-ui",
119
- "sans-serif",
120
- ),
121
- font_mono: fonts.Font
122
- | str
123
- | Iterable[fonts.Font | str] = (
124
- fonts.GoogleFont("IBM Plex Mono"),
125
- "ui-monospace",
126
- "Consolas",
127
- "monospace",
128
- ),
129
- ):
130
- super().__init__(
131
- primary_hue=primary_hue,
132
- secondary_hue=secondary_hue,
133
- neutral_hue=neutral_hue,
134
- spacing_size=spacing_size,
135
- radius_size=radius_size,
136
- text_size=text_size,
137
- font=font,
138
- font_mono=font_mono,
139
- )
140
-
141
-
142
- h2o_logo = '<svg id="Layer_1" data-name="Layer 1" xmlns="http://www.w3.org/2000/svg" width="100%" height="100%"' \
143
- ' viewBox="0 0 600.28 600.28"><defs><style>.cls-1{fill:#fec925;}.cls-2{fill:#161616;}.cls-3{fill:' \
144
- '#54585a;}</style></defs><g id="Fill-1"><rect class="cls-1" width="600.28" height="600.28" ' \
145
- 'rx="23.24"/></g><path class="cls-2" d="M174.33,246.06v92.78H152.86v-38H110.71v38H89.24V246.06h21.' \
146
- '47v36.58h42.15V246.06Z"/><path class="cls-2" d="M259.81,321.34v17.5H189.7V324.92l35.78-33.8c8.22-7.' \
147
- '82,9.68-12.59,9.68-17.09,0-7.29-5-11.53-14.85-11.53-7.95,0-14.71,3-19.21,9.27L185.46,261.7c7.15-10' \
148
- '.47,20.14-17.23,36.84-17.23,20.68,0,34.46,10.6,34.46,27.44,0,9-2.52,17.22-15.51,29.29l-21.33,20.14Z"' \
149
- '/><path class="cls-2" d="M268.69,292.45c0-27.57,21.47-48,50.76-48s50.76,20.28,50.76,48-21.6,48-50.' \
150
- '76,48S268.69,320,268.69,292.45Zm79.78,0c0-17.63-12.46-29.69-29-29.69s-29,12.06-29,29.69,12.46,29.69' \
151
- ',29,29.69S348.47,310.08,348.47,292.45Z"/><path class="cls-3" d="M377.23,326.91c0-7.69,5.7-12.73,12.' \
152
- '85-12.73s12.86,5,12.86,12.73a12.86,12.86,0,1,1-25.71,0Z"/><path class="cls-3" d="M481.4,298.15v40.' \
153
- '69H462.05V330c-3.84,6.49-11.27,9.94-21.74,9.94-16.7,0-26.64-9.28-26.64-21.61,0-12.59,8.88-21.34,30.' \
154
- '62-21.34h16.43c0-8.87-5.3-14-16.43-14-7.55,0-15.37,2.51-20.54,6.62l-7.43-14.44c7.82-5.57,19.35-8.' \
155
- '62,30.75-8.62C468.81,266.47,481.4,276.54,481.4,298.15Zm-20.68,18.16V309H446.54c-9.67,0-12.72,3.57-' \
156
- '12.72,8.35,0,5.16,4.37,8.61,11.66,8.61C452.37,326,458.34,322.8,460.72,316.31Z"/><path class="cls-3"' \
157
- ' d="M497.56,246.06c0-6.49,5.17-11.53,12.86-11.53s12.86,4.77,12.86,11.13c0,6.89-5.17,11.93-12.86,' \
158
- '11.93S497.56,252.55,497.56,246.06Zm2.52,21.47h20.68v71.31H500.08Z"/></svg>'
159
-
160
-
161
- def get_h2o_title(title):
162
- return f"""<div style="display:flex; justify-content:center; margin-bottom:30px;">
163
- <div style="height: 60px; width: 60px; margin-right:20px;">{h2o_logo}</div>
164
- <h1 style="line-height:60px">{title}</h1>
165
- </div>
166
- <div style="float:right; height: 80px; width: 80px; margin-top:-100px">
167
- <img src=https://raw.githubusercontent.com/h2oai/h2ogpt/main/docs/h2o-qr.png></img>
168
- </div>
169
- """
170
-
171
-
172
- def get_simple_title(title):
173
- return f"""<h1 align="center"> {title}</h1>"""
174
-
175
-
176
- def get_dark_js():
177
- return """() => {
178
- if (document.querySelectorAll('.dark').length) {
179
- document.querySelectorAll('.dark').forEach(el => el.classList.remove('dark'));
180
- } else {
181
- document.querySelector('body').classList.add('dark');
182
- }
183
- }"""
gradio_themes.py ADDED
@@ -0,0 +1 @@
+ ../../gradio_themes.py
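As a rough sketch of how the theme module is typically consumed (not the app's actual launch code; it assumes gradio_themes.py is on the import path and the gradio version pinned in requirements.txt):

import gradio as gr

from gradio_themes import H2oTheme, get_h2o_title, get_dark_js

with gr.Blocks(theme=H2oTheme()) as demo:
    gr.HTML(get_h2o_title("h2oGPT"))  # yellow/gray H2O branding plus QR image
    gr.Textbox(label="Ask anything")
    # same _js hook gradio_runner.py uses to toggle dark mode on load
    demo.load(None, None, None, _js=get_dark_js())

demo.launch()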
gradio_ui ADDED
@@ -0,0 +1 @@
+ ../../gradio_ui
h2o-logo.svg DELETED
h2o-logo.svg ADDED
h2oai_pipeline.py DELETED
@@ -1,128 +0,0 @@
1
- from transformers import TextGenerationPipeline
2
- from transformers.pipelines.text_generation import ReturnType
3
-
4
- from stopping import get_stopping
5
- from prompter import Prompter
6
-
7
-
8
- class H2OTextGenerationPipeline(TextGenerationPipeline):
9
- def __init__(self, *args, debug=False, chat=False, stream_output=False,
10
- sanitize_bot_response=True,
11
- use_prompter=True, prompter=None, prompt_type=None,
12
- max_input_tokens=2048 - 256, **kwargs):
13
- """
14
- HF-like pipeline, but handle instruction prompting and stopping (for some models)
15
- :param args:
16
- :param debug:
17
- :param chat:
18
- :param stream_output:
19
- :param sanitize_bot_response:
20
- :param use_prompter: Whether to use prompter. If pass prompt_type, will make prompter
21
- :param prompter: prompter, can pass if have already
22
- :param prompt_type: prompt_type, e.g. human_bot. See prompt_type to model mapping in from prompter.py.
23
- If use_prompter, then will make prompter and use it.
24
- :param max_input_tokens:
25
- :param kwargs:
26
- """
27
- super().__init__(*args, **kwargs)
28
- self.prompt_text = None
29
- self.use_prompter = use_prompter
30
- self.prompt_type = prompt_type
31
- self.prompter = prompter
32
- if self.use_prompter:
33
- if self.prompter is not None:
34
- assert self.prompter.prompt_type is not None
35
- else:
36
- self.prompter = Prompter(self.prompt_type, debug=debug, chat=chat, stream_output=stream_output)
37
- self.human = self.prompter.humanstr
38
- self.bot = self.prompter.botstr
39
- self.can_stop = True
40
- else:
41
- self.prompter = None
42
- self.human = None
43
- self.bot = None
44
- self.can_stop = False
45
- self.sanitize_bot_response = sanitize_bot_response
46
- self.max_input_tokens = max_input_tokens # not for generate, so ok that not kwargs
47
-
48
- def preprocess(self, prompt_text, prefix="", handle_long_generation=None, **generate_kwargs):
49
- data_point = dict(context='', instruction=prompt_text, input='')
50
- if self.prompter is not None:
51
- prompt_text = self.prompter.generate_prompt(data_point)
52
- self.prompt_text = prompt_text
53
- if handle_long_generation is None:
54
- # forces truncation of inputs to avoid critical failure
55
- handle_long_generation = 'hole'
56
- return super().preprocess(prompt_text, prefix=prefix, handle_long_generation=handle_long_generation,
57
- **generate_kwargs)
58
-
59
- def postprocess(self, model_outputs, return_type=ReturnType.FULL_TEXT, clean_up_tokenization_spaces=True):
60
- records = super().postprocess(model_outputs, return_type=return_type,
61
- clean_up_tokenization_spaces=clean_up_tokenization_spaces)
62
- for rec in records:
63
- if self.use_prompter:
64
- outputs = rec['generated_text']
65
- outputs = self.prompter.get_response(outputs, prompt=self.prompt_text,
66
- sanitize_bot_response=self.sanitize_bot_response)
67
- elif self.bot and self.human:
68
- outputs = rec['generated_text'].split(self.bot)[1].strip().split(self.human)[0].strip()
69
- else:
70
- outputs = rec['generated_text']
71
- rec['generated_text'] = outputs
72
- return records
73
-
74
- def _forward(self, model_inputs, **generate_kwargs):
75
- if self.can_stop:
76
- stopping_criteria = get_stopping(self.prompt_type, self.tokenizer, self.device, human=self.human,
77
- bot=self.bot)
78
- generate_kwargs['stopping_criteria'] = stopping_criteria
79
- # return super()._forward(model_inputs, **generate_kwargs)
80
- return self.__forward(model_inputs, **generate_kwargs)
81
-
82
- # FIXME: Copy-paste of original _forward, but removed copy.deepcopy()
83
- # FIXME: https://github.com/h2oai/h2ogpt/issues/172
84
- def __forward(self, model_inputs, **generate_kwargs):
85
- input_ids = model_inputs["input_ids"]
86
- attention_mask = model_inputs.get("attention_mask", None)
87
- # Allow empty prompts
88
- if input_ids.shape[1] == 0:
89
- input_ids = None
90
- attention_mask = None
91
- in_b = 1
92
- else:
93
- in_b = input_ids.shape[0]
94
- prompt_text = model_inputs.pop("prompt_text")
95
-
96
- ## If there is a prefix, we may need to adjust the generation length. Do so without permanently modifying
97
- ## generate_kwargs, as some of the parameterization may come from the initialization of the pipeline.
98
- # generate_kwargs = copy.deepcopy(generate_kwargs)
99
- prefix_length = generate_kwargs.pop("prefix_length", 0)
100
- if prefix_length > 0:
101
- has_max_new_tokens = "max_new_tokens" in generate_kwargs or (
102
- "generation_config" in generate_kwargs
103
- and generate_kwargs["generation_config"].max_new_tokens is not None
104
- )
105
- if not has_max_new_tokens:
106
- generate_kwargs["max_length"] = generate_kwargs.get("max_length") or self.model.config.max_length
107
- generate_kwargs["max_length"] += prefix_length
108
- has_min_new_tokens = "min_new_tokens" in generate_kwargs or (
109
- "generation_config" in generate_kwargs
110
- and generate_kwargs["generation_config"].min_new_tokens is not None
111
- )
112
- if not has_min_new_tokens and "min_length" in generate_kwargs:
113
- generate_kwargs["min_length"] += prefix_length
114
-
115
- # BS x SL
116
- generated_sequence = self.model.generate(input_ids=input_ids, attention_mask=attention_mask, **generate_kwargs)
117
- out_b = generated_sequence.shape[0]
118
- if self.framework == "pt":
119
- generated_sequence = generated_sequence.reshape(in_b, out_b // in_b, *generated_sequence.shape[1:])
120
- elif self.framework == "tf":
121
- from transformers import is_tf_available
122
- if is_tf_available():
123
- import tensorflow as tf
124
- generated_sequence = tf.reshape(generated_sequence,
125
- (in_b, out_b // in_b, *generated_sequence.shape[1:]))
126
- else:
127
- raise ValueError("TF not available.")
128
- return {"generated_sequence": generated_sequence, "input_ids": input_ids, "prompt_text": prompt_text}
h2oai_pipeline.py ADDED
@@ -0,0 +1 @@
+ ../../h2oai_pipeline.py
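A hedged sketch of instantiating the pipeline above directly; the model name is only an example taken from the human_bot mapping in prompter.py, and the real app constructs the pipeline inside generate.py rather than like this:

from transformers import AutoModelForCausalLM, AutoTokenizer

from h2oai_pipeline import H2OTextGenerationPipeline

base_model = 'h2oai/h2ogpt-oig-oasst1-512-6_9b'  # example only; maps to prompt_type 'human_bot'
tokenizer = AutoTokenizer.from_pretrained(base_model)
model = AutoModelForCausalLM.from_pretrained(base_model)

pipe = H2OTextGenerationPipeline(model=model, tokenizer=tokenizer,
                                 prompt_type='human_bot',
                                 sanitize_bot_response=False)  # skip the profanity filter in this sketch
# remaining kwargs flow through to transformers generate()
print(pipe("Why is the sky blue?", max_new_tokens=64)[0]['generated_text'])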
loaders.py DELETED
@@ -1,50 +0,0 @@
- def get_loaders(llama_type, model_name, reward_type):
-     # NOTE: Some models need specific new prompt_type
-     # E.g. t5_xxl_true_nli_mixture has input format: "premise: PREMISE_TEXT hypothesis: HYPOTHESIS_TEXT".)
-     if llama_type:
-         from transformers import LlamaForCausalLM, LlamaTokenizer
-         model_loader = LlamaForCausalLM
-         tokenizer_loader = LlamaTokenizer
-     elif 'distilgpt2' in model_name.lower():
-         from transformers import AutoModelForCausalLM, AutoTokenizer
-         return AutoModelForCausalLM, AutoTokenizer
-     elif 'gpt2' in model_name.lower():
-         from transformers import GPT2LMHeadModel, GPT2Tokenizer
-         return GPT2LMHeadModel, GPT2Tokenizer
-     elif 'mbart-' in model_name.lower():
-         from transformers import MBartForConditionalGeneration, MBart50TokenizerFast
-         return MBartForConditionalGeneration, MBart50TokenizerFast
-     elif 't5' == model_name.lower() or \
-             't5-' in model_name.lower() or \
-             'flan-' in model_name.lower():
-         from transformers import AutoTokenizer, T5ForConditionalGeneration
-         return T5ForConditionalGeneration, AutoTokenizer
-     elif 'bigbird' in model_name:
-         from transformers import BigBirdPegasusForConditionalGeneration, AutoTokenizer
-         return BigBirdPegasusForConditionalGeneration, AutoTokenizer
-     elif 'bart-large-cnn-samsum' in model_name or 'flan-t5-base-samsum' in model_name:
-         from transformers import pipeline
-         return pipeline, "summarization"
-     elif reward_type or 'OpenAssistant/reward-model'.lower() in model_name.lower():
-         from transformers import AutoModelForSequenceClassification, AutoTokenizer
-         return AutoModelForSequenceClassification, AutoTokenizer
-     else:
-         from transformers import AutoTokenizer, AutoModelForCausalLM
-         model_loader = AutoModelForCausalLM
-         tokenizer_loader = AutoTokenizer
-     return model_loader, tokenizer_loader
-
-
- def get_tokenizer(tokenizer_loader, tokenizer_base_model, local_files_only, resume_download, use_auth_token):
-     tokenizer = tokenizer_loader.from_pretrained(tokenizer_base_model,
-                                                  local_files_only=local_files_only,
-                                                  resume_download=resume_download,
-                                                  use_auth_token=use_auth_token)
-
-     tokenizer.pad_token_id = 0  # different from the eos token
-     # when generating, we will use the logits of right-most token to predict the next token
-     # so the padding should be on the left,
-     # e.g. see: https://huggingface.co/transformers/v4.11.3/model_doc/t5.html#inference
-     tokenizer.padding_side = "left"  # Allow batched inference
-
-     return tokenizer
loaders.py ADDED
@@ -0,0 +1 @@
+ ../../loaders.py
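A small usage sketch for the two helpers above; the model name is illustrative and the flags mirror defaults used elsewhere in the repo:

from loaders import get_loaders, get_tokenizer

base_model = 'h2oai/h2ogpt-oig-oasst1-512-6_9b'  # example model
model_loader, tokenizer_loader = get_loaders(llama_type=False, model_name=base_model, reward_type=False)

tokenizer = get_tokenizer(tokenizer_loader, base_model,
                          local_files_only=False, resume_download=True, use_auth_token=False)
model = model_loader.from_pretrained(base_model)  # left-padded tokenizer pairs with the causal LM loader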
prompter.py DELETED
@@ -1,576 +0,0 @@
1
- import time
2
- from enum import Enum
3
-
4
- non_hf_types = ['gpt4all_llama', 'llama', 'gptj']
5
-
6
-
7
- class PromptType(Enum):
8
- plain = 0
9
- instruct = 1
10
- quality = 2
11
- human_bot = 3
12
- dai_faq = 4
13
- summarize = 5
14
- simple_instruct = 6
15
- instruct_vicuna = 7
16
- instruct_with_end = 8
17
- human_bot_orig = 9
18
- prompt_answer = 10
19
- open_assistant = 11
20
- wizard_lm = 12
21
- wizard_mega = 13
22
- instruct_vicuna2 = 14
23
- instruct_vicuna3 = 15
24
- wizard2 = 16
25
- wizard3 = 17
26
-
27
-
28
- prompt_type_to_model_name = {
29
- 'plain': [
30
- 'EleutherAI/gpt-j-6B',
31
- 'EleutherAI/pythia-6.9b',
32
- 'EleutherAI/pythia-12b',
33
- 'EleutherAI/pythia-12b-deduped',
34
- 'EleutherAI/gpt-neox-20b',
35
- 'openlm-research/open_llama_7b_700bt_preview',
36
- 'decapoda-research/llama-7b-hf',
37
- 'decapoda-research/llama-13b-hf',
38
- 'decapoda-research/llama-30b-hf',
39
- 'decapoda-research/llama-65b-hf',
40
- 'facebook/mbart-large-50-many-to-many-mmt',
41
- 'philschmid/bart-large-cnn-samsum',
42
- 'philschmid/flan-t5-base-samsum',
43
- 'gpt2',
44
- 'distilgpt2',
45
- 'mosaicml/mpt-7b-storywriter',
46
- 'mosaicml/mpt-7b-instruct', # internal code handles instruct
47
- 'mosaicml/mpt-7b-chat', # NC, internal code handles instruct
48
- 'gptj', # internally handles prompting
49
- 'llama', # plain, or need to choose prompt_type for given TheBloke model
50
- 'gpt4all_llama', # internally handles prompting
51
- ],
52
- 'prompt_answer': [
53
- 'h2oai/h2ogpt-gm-oasst1-en-1024-20b',
54
- 'h2oai/h2ogpt-gm-oasst1-en-1024-12b',
55
- 'h2oai/h2ogpt-gm-oasst1-multilang-1024-20b',
56
- 'h2oai/h2ogpt-gm-oasst1-en-2048-open-llama-7b-preview-300bt',
57
- 'h2oai/h2ogpt-gm-oasst1-en-2048-open-llama-7b-preview-300bt-v2',
58
- 'h2oai/h2ogpt-gm-oasst1-en-2048-open-llama-7b-preview-700bt',
59
- ],
60
- 'instruct': [],
61
- 'instruct_with_end': ['databricks/dolly-v2-12b'],
62
- 'quality': [],
63
- 'human_bot': [
64
- 'h2oai/h2ogpt-oasst1-512-12b',
65
- 'h2oai/h2ogpt-oasst1-512-20b',
66
- 'h2oai/h2ogpt-oig-oasst1-256-6_9b',
67
- 'h2oai/h2ogpt-oig-oasst1-512-6_9b',
68
- 'h2oai/h2ogpt-oig-oasst1-256-6.9b', # legacy
69
- 'h2oai/h2ogpt-oig-oasst1-512-6.9b', # legacy
70
- 'h2oai/h2ogpt-research-oasst1-512-30b',
71
- 'h2oai/h2ogpt-oasst1-falcon-40b',
72
- 'h2oai/h2ogpt-oig-oasst1-falcon-40b',
73
- ],
74
- 'dai_faq': [],
75
- 'summarize': [],
76
- 'simple_instruct': ['t5-small', 't5-large', 'google/flan-t5', 'google/flan-t5-xxl', 'google/flan-ul2'],
77
- 'instruct_vicuna': ['AlekseyKorshuk/vicuna-7b', 'TheBloke/stable-vicuna-13B-HF', 'junelee/wizard-vicuna-13b'],
78
- 'human_bot_orig': ['togethercomputer/GPT-NeoXT-Chat-Base-20B'],
79
- "open_assistant": ['OpenAssistant/oasst-sft-7-llama-30b-xor', 'oasst-sft-7-llama-30b'],
80
- "wizard_lm": ['ehartford/WizardLM-7B-Uncensored', 'ehartford/WizardLM-13B-Uncensored'],
81
- "wizard_mega": ['openaccess-ai-collective/wizard-mega-13b'],
82
- }
83
-
84
- inv_prompt_type_to_model_name = {v.strip(): k for k, l in prompt_type_to_model_name.items() for v in l}
85
- inv_prompt_type_to_model_lower = {v.strip().lower(): k for k, l in prompt_type_to_model_name.items() for v in l}
86
-
87
- prompt_types_strings = []
88
- for p in PromptType:
89
- prompt_types_strings.extend([p.name])
90
-
91
- prompt_types = []
92
- for p in PromptType:
93
- prompt_types.extend([p.name, p.value, str(p.value)])
94
-
95
-
96
- def get_prompt(prompt_type, chat, context, reduced):
97
- if prompt_type in [PromptType.plain.value, str(PromptType.plain.value),
98
- PromptType.plain.name]:
99
- promptA = promptB = PreInstruct = PreInput = PreResponse = ''
100
- terminate_response = []
101
- chat_sep = ''
102
- humanstr = ''
103
- botstr = ''
104
- elif prompt_type == 'simple_instruct':
105
- promptA = promptB = PreInstruct = PreInput = PreResponse = None
106
- terminate_response = []
107
- chat_sep = '\n'
108
- humanstr = ''
109
- botstr = ''
110
- elif prompt_type in [PromptType.instruct.value, str(PromptType.instruct.value),
111
- PromptType.instruct.name] + [PromptType.instruct_with_end.value,
112
- str(PromptType.instruct_with_end.value),
113
- PromptType.instruct_with_end.name]:
114
- promptA = 'Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.\n' if not (
115
- chat and reduced) else ''
116
- promptB = 'Below is an instruction that describes a task. Write a response that appropriately completes the request.\n' if not (
117
- chat and reduced) else ''
118
-
119
- PreInstruct = """
120
- ### Instruction:
121
- """
122
-
123
- PreInput = """
124
- ### Input:
125
- """
126
-
127
- PreResponse = """
128
- ### Response:
129
- """
130
- if prompt_type in [PromptType.instruct_with_end.value, str(PromptType.instruct_with_end.value),
131
- PromptType.instruct_with_end.name]:
132
- terminate_response = ['### End']
133
- else:
134
- terminate_response = None
135
- chat_sep = '\n'
136
- humanstr = PreInstruct
137
- botstr = PreResponse
138
- elif prompt_type in [PromptType.quality.value, str(PromptType.quality.value),
139
- PromptType.quality.name]:
140
- promptA = 'Write a detailed high-quality, accurate, fair, Response with about 100 words by following the Instruction as applied on the Input.\n' if not (
141
- chat and reduced) else ''
142
- promptB = 'Write a detailed high-quality, accurate, fair, Response with about 100 words by following the Instruction.\n' if not (
143
- chat and reduced) else ''
144
-
145
- PreInstruct = """
146
- ### Instruction:
147
- """
148
-
149
- PreInput = """
150
- ### Input:
151
- """
152
-
153
- PreResponse = """
154
- ### Response:
155
- """
156
- terminate_response = None
157
- chat_sep = '\n'
158
- humanstr = PreInstruct # first thing human says
159
- botstr = PreResponse # first thing bot says
160
- elif prompt_type in [PromptType.human_bot.value, str(PromptType.human_bot.value),
161
- PromptType.human_bot.name] + [PromptType.human_bot_orig.value,
162
- str(PromptType.human_bot_orig.value),
163
- PromptType.human_bot_orig.name]:
164
- human = '<human>:'
165
- bot = "<bot>:"
166
- if reduced or context or prompt_type in [PromptType.human_bot.value, str(PromptType.human_bot.value),
167
- PromptType.human_bot.name]:
168
- preprompt = ''
169
- else:
170
- cur_date = time.strftime('%Y-%m-%d')
171
- cur_time = time.strftime('%H:%M:%S %p %Z')
172
-
173
- PRE_PROMPT = """\
174
- Current Date: {}
175
- Current Time: {}
176
-
177
- """
178
- preprompt = PRE_PROMPT.format(cur_date, cur_time)
179
- start = human
180
- promptB = promptA = '%s%s ' % (preprompt, start)
181
-
182
- PreInstruct = ""
183
-
184
- PreInput = None
185
-
186
- if reduced:
187
- # when making context, want it to appear as-if LLM generated, which starts with space after :
188
- PreResponse = bot + ' '
189
- else:
190
- # normally LLM adds space after this, because was how trained.
191
- # if add space here, non-unique tokenization will often make LLM produce wrong output
192
- PreResponse = bot
193
-
194
- terminate_response = [start, PreResponse]
195
- chat_sep = '\n'
196
- humanstr = human # tag before human talks
197
- botstr = bot # tag before bot talks
198
- elif prompt_type in [PromptType.dai_faq.value, str(PromptType.dai_faq.value),
199
- PromptType.dai_faq.name]:
200
- promptA = ''
201
- promptB = 'Answer the following Driverless AI question.\n'
202
-
203
- PreInstruct = """
204
- ### Driverless AI frequently asked question:
205
- """
206
-
207
- PreInput = None
208
-
209
- PreResponse = """
210
- ### Driverless AI documentation answer:
211
- """
212
- terminate_response = ['\n\n']
213
- chat_sep = terminate_response
214
- humanstr = PreInstruct
215
- botstr = PreResponse
216
- elif prompt_type in [PromptType.summarize.value, str(PromptType.summarize.value),
217
- PromptType.summarize.name]:
218
- promptA = promptB = PreInput = ''
219
- PreInstruct = '## Main Text\n\n'
220
- PreResponse = '\n\n## Summary\n\n'
221
- terminate_response = None
222
- chat_sep = '\n'
223
- humanstr = PreInstruct
224
- botstr = PreResponse
225
- elif prompt_type in [PromptType.instruct_vicuna.value, str(PromptType.instruct_vicuna.value),
226
- PromptType.instruct_vicuna.name]:
227
- promptA = promptB = "A chat between a curious human and an artificial intelligence assistant. " \
228
- "The assistant gives helpful, detailed, and polite answers to the human's questions." if not (
229
- chat and reduced) else ''
230
-
231
- PreInstruct = """
232
- ### Human:
233
- """
234
-
235
- PreInput = None
236
-
237
- PreResponse = """
238
- ### Assistant:
239
- """
240
- terminate_response = [
241
- '### Human:'] # but only allow terminate after prompt is found correctly, else can't terminate
242
- chat_sep = '\n'
243
- humanstr = PreInstruct
244
- botstr = PreResponse
245
- elif prompt_type in [PromptType.prompt_answer.value, str(PromptType.prompt_answer.value),
246
- PromptType.prompt_answer.name]:
247
- preprompt = ''
248
- prompt_tokens = "<|prompt|>"
249
- answer_tokens = "<|answer|>"
250
- start = prompt_tokens
251
- promptB = promptA = '%s%s' % (preprompt, start)
252
- PreInstruct = ""
253
- PreInput = None
254
- PreResponse = answer_tokens
255
- eos = '<|endoftext|>' # neox eos
256
- terminate_response = [start, PreResponse, eos]
257
- chat_sep = eos
258
- humanstr = prompt_tokens
259
- botstr = answer_tokens
260
- elif prompt_type in [PromptType.open_assistant.value, str(PromptType.open_assistant.value),
261
- PromptType.open_assistant.name]:
262
- # From added_tokens.json
263
- preprompt = ''
264
- prompt_tokens = "<|prompter|>"
265
- answer_tokens = "<|assistant|>"
266
- start = prompt_tokens
267
- promptB = promptA = '%s%s' % (preprompt, start)
268
- PreInstruct = ""
269
- PreInput = None
270
- PreResponse = answer_tokens
271
- pend = "<|prefix_end|>"
272
- eos = "</s>"
273
- terminate_response = [start, PreResponse, pend, eos]
274
- chat_sep = eos
275
- humanstr = prompt_tokens
276
- botstr = answer_tokens
277
- elif prompt_type in [PromptType.wizard_lm.value, str(PromptType.wizard_lm.value),
278
- PromptType.wizard_lm.name]:
279
- # https://github.com/ehartford/WizardLM/blob/main/src/train_freeform.py
280
- preprompt = ''
281
- start = ''
282
- promptB = promptA = '%s%s' % (preprompt, start)
283
- PreInstruct = ""
284
- PreInput = None
285
- PreResponse = "\n\n### Response\n"
286
- eos = "</s>"
287
- terminate_response = [PreResponse, eos]
288
- chat_sep = eos
289
- humanstr = promptA
290
- botstr = PreResponse
291
- elif prompt_type in [PromptType.wizard_mega.value, str(PromptType.wizard_mega.value),
292
- PromptType.wizard_mega.name]:
293
- preprompt = ''
294
- start = ''
295
- promptB = promptA = '%s%s' % (preprompt, start)
296
- PreInstruct = """
297
- ### Instruction:
298
- """
299
- PreInput = None
300
- PreResponse = """
301
- ### Assistant:
302
- """
303
- terminate_response = [PreResponse]
304
- chat_sep = '\n'
305
- humanstr = PreInstruct
306
- botstr = PreResponse
307
- elif prompt_type in [PromptType.instruct_vicuna2.value, str(PromptType.instruct_vicuna2.value),
308
- PromptType.instruct_vicuna2.name]:
309
- promptA = promptB = "" if not (
310
- chat and reduced) else ''
311
-
312
- PreInstruct = """
313
- HUMAN:
314
- """
315
-
316
- PreInput = None
317
-
318
- PreResponse = """
319
- ASSISTANT:
320
- """
321
- terminate_response = [
322
- 'HUMAN:'] # but only allow terminate after prompt is found correctly, else can't terminate
323
- chat_sep = '\n'
324
- humanstr = PreInstruct
325
- botstr = PreResponse
326
- elif prompt_type in [PromptType.instruct_vicuna3.value, str(PromptType.instruct_vicuna3.value),
327
- PromptType.instruct_vicuna3.name]:
328
- promptA = promptB = "" if not (
329
- chat and reduced) else ''
330
-
331
- PreInstruct = """
332
- ### User:
333
- """
334
-
335
- PreInput = None
336
-
337
- PreResponse = """
338
- ### Assistant:
339
- """
340
- terminate_response = [
341
- '### User:'] # but only allow terminate after prompt is found correctly, else can't terminate
342
- chat_sep = '\n'
343
- humanstr = PreInstruct
344
- botstr = PreResponse
345
- elif prompt_type in [PromptType.wizard2.value, str(PromptType.wizard2.value),
346
- PromptType.wizard2.name]:
347
- # https://huggingface.co/TheBloke/WizardLM-7B-uncensored-GGML
348
- preprompt = """Below is an instruction that describes a task. Write a response that appropriately completes the request."""
349
- start = ''
350
- promptB = promptA = '%s%s' % (preprompt, start)
351
- PreInstruct = """
352
- ### Instruction:
353
- """
354
- PreInput = None
355
- PreResponse = """
356
- ### Response:
357
- """
358
- terminate_response = [PreResponse]
359
- chat_sep = '\n'
360
- humanstr = PreInstruct
361
- botstr = PreResponse
362
- elif prompt_type in [PromptType.wizard3.value, str(PromptType.wizard3.value),
363
- PromptType.wizard3.name]:
364
- # https://huggingface.co/TheBloke/wizardLM-13B-1.0-GGML
365
- preprompt = """A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions."""
366
- start = ''
367
- promptB = promptA = '%s%s' % (preprompt, start)
368
- PreInstruct = """USER: """
369
- PreInput = None
370
- PreResponse = """ASSISTANT: """
371
- terminate_response = [PreResponse]
372
- chat_sep = '\n'
373
- humanstr = PreInstruct
374
- botstr = PreResponse
375
-
376
- else:
377
- raise RuntimeError("No such prompt_type=%s" % prompt_type)
378
-
379
- return promptA, promptB, PreInstruct, PreInput, PreResponse, terminate_response, chat_sep, humanstr, botstr
380
-
381
-
382
- def generate_prompt(data_point, prompt_type, chat, reduced):
383
- context = data_point.get('context')
384
- if context is None:
385
- context = ''
386
- instruction = data_point.get('instruction')
387
- input = data_point.get('input')
388
- output = data_point.get('output')
389
- prompt_type = data_point.get('prompt_type', prompt_type)
390
- assert prompt_type in prompt_types, "Bad prompt type: %s" % prompt_type
391
- promptA, promptB, PreInstruct, PreInput, PreResponse, \
392
- terminate_response, chat_sep, humanstr, botstr = get_prompt(prompt_type, chat, context, reduced)
393
-
394
- prompt = context if not reduced else ''
395
-
396
- if input and promptA:
397
- prompt += f"""{promptA}"""
398
- elif promptB:
399
- prompt += f"""{promptB}"""
400
-
401
- if instruction and PreInstruct is not None and input and PreInput is not None:
402
- prompt += f"""{PreInstruct}{instruction}{PreInput}{input}"""
403
- prompt = inject_newline(prompt_type, prompt)
404
- elif instruction and input and PreInstruct is None and PreInput is not None:
405
- prompt += f"""{PreInput}{instruction}
406
- {input}"""
407
- prompt = inject_newline(prompt_type, prompt)
408
- elif input and instruction and PreInput is None and PreInstruct is not None:
409
- prompt += f"""{PreInstruct}{instruction}
410
- {input}"""
411
- prompt = inject_newline(prompt_type, prompt)
412
- elif instruction and PreInstruct is not None:
413
- prompt += f"""{PreInstruct}{instruction}"""
414
- prompt = inject_newline(prompt_type, prompt)
415
- elif input and PreInput is not None:
416
- prompt += f"""{PreInput}{input}"""
417
- prompt = inject_newline(prompt_type, prompt)
418
- elif input and instruction and PreInput is not None:
419
- prompt += f"""{PreInput}{instruction}{input}"""
420
- prompt = inject_newline(prompt_type, prompt)
421
- elif input and instruction and PreInstruct is not None:
422
- prompt += f"""{PreInstruct}{instruction}{input}"""
423
- prompt = inject_newline(prompt_type, prompt)
424
- elif input and instruction:
425
- # i.e. for simple_instruct
426
- prompt += f"""{instruction}: {input}"""
427
- prompt = inject_newline(prompt_type, prompt)
428
- elif input:
429
- prompt += f"""{input}"""
430
- prompt = inject_newline(prompt_type, prompt)
431
- elif instruction:
432
- prompt += f"""{instruction}"""
433
- prompt = inject_newline(prompt_type, prompt)
434
-
435
- if PreResponse is not None:
436
- prompt += f"""{PreResponse}"""
437
- pre_response = PreResponse # Don't use strip
438
- else:
439
- pre_response = ''
440
-
441
- if output:
442
- prompt += f"""{output}"""
443
-
444
- return prompt, pre_response, terminate_response, chat_sep
445
-
446
-
447
- def inject_newline(prompt_type, prompt):
448
- if prompt_type not in [-1, '-1', 'plain', 'simple_instruct']:
449
- # only add new line if structured prompt, while 'plain' is just generation of next tokens from input
450
- prompt += '\n'
451
- return prompt
452
-
453
-
454
- class Prompter(object):
455
- def __init__(self, prompt_type, debug=False, chat=False, stream_output=False, repeat_penalty=True,
456
- allowed_repeat_line_length=10):
457
- self.prompt_type = prompt_type
458
- data_point = dict(instruction='', input='', output='')
459
- _, self.pre_response, self.terminate_response, self.chat_sep = \
460
- generate_prompt(data_point, prompt_type, chat, False)
461
- self.debug = debug
462
- self.chat = chat
463
- self.stream_output = stream_output
464
- self.repeat_penalty = repeat_penalty
465
- self.allowed_repeat_line_length = allowed_repeat_line_length
466
- self.prompt = None
467
- context = "" # not for chat context
468
- reduced = False # not for chat context
469
- self.promptA, self.promptB, self.PreInstruct, self.PreInput, self.PreResponse, \
470
- self.terminate_response, self.chat_sep, self.humanstr, self.botstr = \
471
- get_prompt(prompt_type, chat, context, reduced)
472
-
473
- def generate_prompt(self, data_point):
474
- reduced = False
475
- prompt, _, _, _ = generate_prompt(data_point, self.prompt_type, self.chat, reduced)
476
- if self.debug:
477
- print("prompt: ", prompt, flush=True)
478
- self.prompt = prompt
479
- return prompt
480
-
481
- def get_response(self, outputs, prompt=None, sanitize_bot_response=True):
482
- if isinstance(outputs, str):
483
- outputs = [outputs]
484
- if self.debug:
485
- print("output:\n", '\n\n'.join(outputs), flush=True)
486
- if prompt is not None:
487
- self.prompt = prompt
488
-
489
- def clean_response(response):
490
- meaningless_words = ['<pad>', '</s>', '<|endoftext|>']
491
- for word in meaningless_words:
492
- response = response.replace(word, "")
493
- if sanitize_bot_response:
494
- from better_profanity import profanity
495
- response = profanity.censor(response)
496
- response = response.strip("\n")
497
- return response
498
-
499
- def clean_repeats(response):
500
- lines = response.split('\n')
501
- new_lines = []
502
- [new_lines.append(line) for line in lines if
503
- line not in new_lines or len(line) < self.allowed_repeat_line_length]
504
- if self.debug and len(lines) != len(new_lines):
505
- print("cleaned repeats: %s %s" % (len(lines), len(new_lines)), flush=True)
506
- response = '\n'.join(new_lines)
507
- return response
508
-
509
- multi_output = len(outputs) > 1
510
-
511
- for oi, output in enumerate(outputs):
512
- if self.prompt_type in [PromptType.plain.value, str(PromptType.plain.value), PromptType.plain.name]:
513
- output = clean_response(output)
514
- elif prompt is None:
515
- # then use most basic parsing like pipeline
516
- if self.botstr in output:
517
- if self.humanstr:
518
- output = clean_response(output.split(self.botstr)[1].strip().split(self.humanstr)[0].strip())
519
- else:
520
- # i.e. use after bot but only up to next bot
521
- output = clean_response(output.split(self.botstr)[1].strip().split(self.botstr)[0].strip())
522
- else:
523
- # output = clean_response(output.strip())
524
- # assume just not printed yet
525
- output = ""
526
- else:
527
- # find first instance of pre_response
528
- # prompt sometimes has odd characters, that mutate length,
529
- # so can't go by length alone
530
- if self.pre_response:
531
- outputi = output.find(prompt)
532
- if outputi >= 0:
533
- output = output[outputi + len(prompt):]
534
- allow_terminate = True
535
- else:
536
- # subtraction is risky due to space offsets sometimes, so only do if necessary
537
- output = output[len(prompt) - len(self.pre_response):]
538
- # [1] to avoid repeated pre_response, just take first (after prompt - pre_response for chat)
539
- if self.pre_response in output:
540
- output = output.split(self.pre_response)[1]
541
- allow_terminate = True
542
- else:
543
- if output:
544
- print("Failure of parsing or not enough output yet: %s" % output, flush=True)
545
- allow_terminate = False
546
- else:
547
- allow_terminate = True
548
- output = output[len(prompt):]
549
- # clean after subtract prompt out, so correct removal of pre_response
550
- output = clean_response(output).strip()
551
- if self.repeat_penalty:
552
- output = clean_repeats(output).strip()
553
- if self.terminate_response and allow_terminate:
554
- finds = []
555
- for term in self.terminate_response:
556
- finds.append(output.find(term))
557
- finds = [x for x in finds if x >= 0]
558
- if len(finds) > 0:
559
- termi = finds[0]
560
- output = output[:termi].strip()
561
- else:
562
- output = output.strip()
563
- else:
564
- output = output.strip()
565
- if multi_output:
566
- # prefix with output counter
567
- output = "\n=========== Output %d\n\n" % (1 + oi) + output
568
- if oi > 0:
569
- # post-fix outputs with separator
570
- output += '\n'
571
- outputs[oi] = output
572
- # join all outputs, only one extra new line between outputs
573
- output = '\n'.join(outputs)
574
- if self.debug:
575
- print("outputclean:\n", '\n\n'.join(outputs), flush=True)
576
- return output
prompter.py ADDED
@@ -0,0 +1 @@
+ ../../prompter.py
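A brief round-trip sketch for the Prompter class above: build a prompt for one data point, generate with any backend, then parse the raw output back out. sanitize_bot_response=False just avoids the better_profanity dependency in a standalone snippet, and the fake raw_output stands in for model.generate():

from prompter import Prompter

prompter = Prompter('human_bot', chat=True, stream_output=False)
data_point = dict(context='', instruction='Why is the sky blue?', input='', output='')
prompt = prompter.generate_prompt(data_point)  # "<human>: Why is the sky blue?\n<bot>:"

# raw_output would normally come from the model given `prompt`
raw_output = prompt + " Because of Rayleigh scattering.\n<human>:"
response = prompter.get_response(raw_output, prompt=prompt, sanitize_bot_response=False)
print(response)  # stops at the next "<human>:" turn -> "Because of Rayleigh scattering."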
requirements.txt CHANGED
@@ -1,100 +0,0 @@
1
- # for generate (gradio server) and finetune
2
- datasets==2.12.0
3
- sentencepiece==0.1.97
4
- gradio==3.31.0
5
- huggingface_hub==0.14.1
6
- appdirs==1.4.4
7
- fire==0.5.0
8
- docutils==0.19
9
- torch==2.0.1
10
- evaluate==0.4.0
11
- rouge_score==0.1.2
12
- sacrebleu==2.3.1
13
- scikit-learn==1.2.2
14
- alt-profanity-check==1.2.2
15
- better-profanity==0.6.1
16
- numpy==1.24.2
17
- pandas==2.0.0
18
- matplotlib==3.7.1
19
- loralib==0.1.1
20
- bitsandbytes==0.39.0
21
- accelerate==0.19.0
22
- git+https://github.com/huggingface/peft.git@3714aa2fff158fdfa637b2b65952580801d890b2
23
- transformers==4.28.1
24
- tokenizers==0.13.3
25
- APScheduler==3.10.1
26
-
27
- # optional for generate
28
- pynvml==11.5.0
29
- psutil==5.9.4
30
- boto3==1.26.101
31
- botocore==1.29.101
32
-
33
- # optional for finetune
34
- tensorboard==2.12.1
35
- neptune==1.1.1
36
-
37
- # for gradio client
38
- gradio_client==0.2.5
39
- beautifulsoup4==4.12.2
40
- markdown==3.4.1
41
-
42
- # data and testing
43
- pytest==7.2.2
44
- pytest-xdist==3.2.1
45
- nltk==3.8.1
46
- textstat==0.7.3
47
- pandoc==2.3
48
- #pypandoc==1.11
49
- pypandoc_binary==1.11
50
- openpyxl==3.1.2
51
- lm_dataformat==0.0.20
52
- bioc==2.0
53
-
54
- # falcon
55
- einops==0.6.1
56
- # optional for chat with PDF
57
- langchain==0.0.183
58
- pypdf==3.8.1
59
- tiktoken==0.3.3
60
- # avoid textract, requires old six
61
- #textract==1.6.5
62
-
63
- # for HF embeddings
64
- sentence_transformers==2.2.2
65
- # for OpenAI embeddings (requires key)
66
- openai==0.27.6
67
-
68
- # local vector db
69
- chromadb==0.3.25
70
- # server vector db
71
- #pymilvus==2.2.8
72
-
73
- # weak url support, if can't install opencv etc. If comment-in this one, then comment-out unstructured[local-inference]==0.6.6
74
- # unstructured==0.6.6
75
-
76
- # strong support for images
77
- # Requires on Ubuntu: sudo apt-get install libmagic-dev poppler-utils tesseract-ocr libreoffice
78
- unstructured[local-inference]==0.6.6
79
- #pdf2image==1.16.3
80
- #pytesseract==0.3.10
81
- pillow
82
-
83
- pdfminer.six==20221105
84
- urllib3==1.26.6
85
- requests_file==1.5.1
86
-
87
- #pdf2image==1.16.3
88
- #pytesseract==0.3.10
89
- tabulate==0.9.0
90
- # FYI pandoc already part of requirements.txt
91
-
92
- # JSONLoader, but makes some trouble for some users
93
- # jq==1.4.1
94
-
95
- # to check licenses
96
- # Run: pip-licenses|grep -v 'BSD\|Apache\|MIT'
97
- pip-licenses==4.3.0
98
-
99
- # weaviate vector db
100
- weaviate-client==3.19.2
stopping.py DELETED
@@ -1,72 +0,0 @@
1
- import torch
2
- from transformers import StoppingCriteria, StoppingCriteriaList
3
-
4
- from prompter import PromptType
5
-
6
-
7
- class StoppingCriteriaSub(StoppingCriteria):
8
-
9
- def __init__(self, stops=[], encounters=[], device="cuda"):
10
- super().__init__()
11
- assert len(stops) % len(encounters) == 0, "Number of stops and encounters must match"
12
- self.encounters = encounters
13
- self.stops = [stop.to(device) for stop in stops]
14
- self.num_stops = [0] * len(stops)
15
-
16
- def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
17
- for stopi, stop in enumerate(self.stops):
18
- if torch.all((stop == input_ids[0][-len(stop):])).item():
19
- self.num_stops[stopi] += 1
20
- if self.num_stops[stopi] >= self.encounters[stopi % len(self.encounters)]:
21
- # print("Stopped", flush=True)
22
- return True
23
- # print("Tokens: %s" % input_ids[0].cpu().numpy(), flush=True)
24
- # print("Stop Tokens: %s" % [x.cpu().numpy() for x in self.stops], flush=True)
25
- return False
26
-
27
-
28
- def get_stopping(prompt_type, tokenizer, device, human='<human>:', bot="<bot>:"):
29
- if prompt_type in [PromptType.human_bot.name, PromptType.instruct_vicuna.name, PromptType.instruct_with_end.name]:
30
- if prompt_type == PromptType.human_bot.name:
31
- # encounters = [prompt.count(human) + 1, prompt.count(bot) + 1]
32
- # stopping only starts once output is beyond prompt
33
- # 1 human is enough to trigger, but need 2 bots, because very first view back will be bot we added
34
- stop_words = [human, bot, '\n' + human, '\n' + bot]
35
- encounters = [1, 2]
36
- elif prompt_type == PromptType.instruct_vicuna.name:
37
- # even below is not enough, generic strings and many ways to encode
38
- stop_words = [
39
- '### Human:',
40
- """
41
- ### Human:""",
42
- """
43
- ### Human:
44
- """,
45
- '### Assistant:',
46
- """
47
- ### Assistant:""",
48
- """
49
- ### Assistant:
50
- """,
51
- ]
52
- encounters = [1, 2]
53
- else:
54
- # some instruct prompts have this as end, doesn't hurt to stop on it since not common otherwise
55
- stop_words = ['### End']
56
- encounters = [1]
57
- stop_words_ids = [
58
- tokenizer(stop_word, return_tensors='pt')['input_ids'].squeeze() for stop_word in stop_words]
59
- # handle single token case
60
- stop_words_ids = [x if len(x.shape) > 0 else torch.tensor([x]) for x in stop_words_ids]
61
- stop_words_ids = [x for x in stop_words_ids if x.shape[0] > 0]
62
- # avoid padding in front of tokens
63
- if tokenizer._pad_token: # use hidden variable to avoid annoying properly logger bug
64
- stop_words_ids = [x[1:] if x[0] == tokenizer.pad_token_id and len(x) > 1 else x for x in stop_words_ids]
65
- # handle fake \n added
66
- stop_words_ids = [x[1:] if y[0] == '\n' else x for x, y in zip(stop_words_ids, stop_words)]
67
- # build stopper
68
- stopping_criteria = StoppingCriteriaList(
69
- [StoppingCriteriaSub(stops=stop_words_ids, encounters=encounters, device=device)])
70
- else:
71
- stopping_criteria = StoppingCriteriaList()
72
- return stopping_criteria
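
A minimal usage sketch for the module above (not part of the commit): get_stopping() returns a StoppingCriteriaList that is passed straight to transformers generate(), so generation halts once the model begins a new conversational turn. The checkpoint name and prompt are illustrative assumptions.

from transformers import AutoModelForCausalLM, AutoTokenizer

from prompter import PromptType
from stopping import get_stopping

model_name = "h2oai/h2ogpt-oig-oasst1-512-6.9b"  # example checkpoint, not prescribed by this commit
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)

# Stop as soon as the model starts another <human>:/<bot>: turn.
stopping_criteria = get_stopping(PromptType.human_bot.name, tokenizer, device="cpu")
inputs = tokenizer("<human>: What license does h2oGPT use?\n<bot>:", return_tensors="pt")
output = model.generate(**inputs, max_new_tokens=64, stopping_criteria=stopping_criteria)
print(tokenizer.decode(output[0], skip_special_tokens=True))
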
stopping.py ADDED
@@ -0,0 +1 @@
1
+ ../../stopping.py
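
The new blob contains only the relative path ../../stopping.py, which is how git records a symbolic link: the Space now points at the shared copy two directories up instead of carrying a duplicate file. A small sketch of creating and inspecting such a link locally, assuming a POSIX filesystem and that it is run from this directory:

import os

# Recreate the pointer file if it is missing, then verify where it leads.
if not os.path.lexists("stopping.py"):
    os.symlink("../../stopping.py", "stopping.py")

print(os.path.islink("stopping.py"))    # True for a symlink
print(os.readlink("stopping.py"))       # '../../stopping.py'
print(os.path.realpath("stopping.py"))  # absolute path of the shared module
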
utils.py DELETED
@@ -1,843 +0,0 @@
1
- import contextlib
2
- import functools
3
- import hashlib
4
- import inspect
5
- import os
6
- import gc
7
- import pathlib
8
- import random
9
- import shutil
10
- import subprocess
11
- import sys
12
- import threading
13
- import time
14
- import traceback
15
- import zipfile
16
- from datetime import datetime
17
- import filelock
18
- import requests, uuid
19
- from typing import Tuple, Callable, Dict
20
- from tqdm.auto import tqdm
21
- from joblib import Parallel
22
- from concurrent.futures import ProcessPoolExecutor
23
- import numpy as np
24
- import pandas as pd
25
-
26
-
27
- def set_seed(seed: int):
28
- """
29
- Sets the seed of the entire notebook so results are the same every time we run.
30
- This is for REPRODUCIBILITY.
31
- """
32
- import torch
33
- np.random.seed(seed)
34
- random_state = np.random.RandomState(seed)
35
- random.seed(seed)
36
- torch.manual_seed(seed)
37
- torch.cuda.manual_seed(seed)
38
- torch.backends.cudnn.deterministic = True
39
- torch.backends.cudnn.benchmark = False
40
- os.environ['PYTHONHASHSEED'] = str(seed)
41
- return random_state
42
-
43
-
44
- def flatten_list(lis):
45
- """Given a list, possibly nested to any level, return it flattened."""
46
- new_lis = []
47
- for item in lis:
48
- if type(item) == type([]):
49
- new_lis.extend(flatten_list(item))
50
- else:
51
- new_lis.append(item)
52
- return new_lis
53
-
54
-
55
- def clear_torch_cache():
56
- import torch
57
- if torch.cuda.is_available():
58
- torch.cuda.empty_cache()
59
- torch.cuda.ipc_collect()
60
- gc.collect()
61
-
62
-
63
- def ping():
64
- try:
65
- print('Ping: %s' % str(datetime.now()), flush=True)
66
- except AttributeError:
67
- # some programs wrap print and will fail with flush passed
68
- pass
69
-
70
-
71
- def get_torch_allocated():
72
- import torch
73
- return torch.cuda.memory_allocated()
74
-
75
-
76
- def get_device():
77
- import torch
78
- if torch.cuda.is_available():
79
- device = "cuda"
80
- else:
81
- device = "cpu"
82
-
83
- return device
84
-
85
-
86
- def system_info():
87
- import psutil
88
-
89
- system = {}
90
- # https://stackoverflow.com/questions/48951136/plot-multiple-graphs-in-one-plot-using-tensorboard
91
- # https://arshren.medium.com/monitoring-your-devices-in-python-5191d672f749
92
- temps = psutil.sensors_temperatures(fahrenheit=False)
93
- if 'coretemp' in temps:
94
- coretemp = temps['coretemp']
95
- temp_dict = {k.label: k.current for k in coretemp}
96
- for k, v in temp_dict.items():
97
- system['CPU_C/%s' % k] = v
98
-
99
- # https://github.com/gpuopenanalytics/pynvml/blob/master/help_query_gpu.txt
100
- from pynvml.smi import nvidia_smi
101
- nvsmi = nvidia_smi.getInstance()
102
-
103
- gpu_power_dict = {'W_gpu%d' % i: x['power_readings']['power_draw'] for i, x in
104
- enumerate(nvsmi.DeviceQuery('power.draw')['gpu'])}
105
- for k, v in gpu_power_dict.items():
106
- system['GPU_W/%s' % k] = v
107
-
108
- gpu_temp_dict = {'C_gpu%d' % i: x['temperature']['gpu_temp'] for i, x in
109
- enumerate(nvsmi.DeviceQuery('temperature.gpu')['gpu'])}
110
- for k, v in gpu_temp_dict.items():
111
- system['GPU_C/%s' % k] = v
112
-
113
- gpu_memory_free_dict = {'MiB_gpu%d' % i: x['fb_memory_usage']['free'] for i, x in
114
- enumerate(nvsmi.DeviceQuery('memory.free')['gpu'])}
115
- gpu_memory_total_dict = {'MiB_gpu%d' % i: x['fb_memory_usage']['total'] for i, x in
116
- enumerate(nvsmi.DeviceQuery('memory.total')['gpu'])}
117
- gpu_memory_frac_dict = {k: gpu_memory_free_dict[k] / gpu_memory_total_dict[k] for k in gpu_memory_total_dict}
118
- for k, v in gpu_memory_frac_dict.items():
119
- system[f'GPU_M/%s' % k] = v
120
-
121
- system['hash'] = get_githash()
122
-
123
- return system
124
-
125
-
126
- def system_info_print():
127
- try:
128
- df = pd.DataFrame.from_dict(system_info(), orient='index')
129
- # avoid slamming GPUs
130
- time.sleep(1)
131
- return df.to_markdown()
132
- except Exception as e:
133
- return "Error: %s" % str(e)
134
-
135
-
136
- def zip_data(root_dirs=None, zip_file=None, base_dir='./', fail_any_exception=False):
137
- try:
138
- return _zip_data(zip_file=zip_file, base_dir=base_dir, root_dirs=root_dirs)
139
- except Exception as e:
140
- traceback.print_exc()
141
- print('Exception in zipping: %s' % str(e))
142
- if fail_any_exception:  # re-raise only when the caller wants zip failures to be fatal
143
- raise
144
-
145
-
146
- def _zip_data(root_dirs=None, zip_file=None, base_dir='./'):
147
- if isinstance(root_dirs, str):
148
- root_dirs = [root_dirs]
149
- if zip_file is None:
150
- datetime_str = str(datetime.now()).replace(" ", "_").replace(":", "_")
151
- host_name = os.getenv('HF_HOSTNAME', 'emptyhost')
152
- zip_file = "data_%s_%s.zip" % (datetime_str, host_name)
153
- assert root_dirs is not None
154
- if not os.path.isdir(os.path.dirname(zip_file)) and os.path.dirname(zip_file):
155
- os.makedirs(os.path.dirname(zip_file), exist_ok=True)
156
- with zipfile.ZipFile(zip_file, "w") as expt_zip:
157
- for root_dir in root_dirs:
158
- if root_dir is None:
159
- continue
160
- for root, d, files in os.walk(root_dir):
161
- for file in files:
162
- file_to_archive = os.path.join(root, file)
163
- assert os.path.exists(file_to_archive)
164
- path_to_archive = os.path.relpath(file_to_archive, base_dir)
165
- expt_zip.write(filename=file_to_archive, arcname=path_to_archive)
166
- return zip_file, zip_file
167
-
168
-
169
- def save_generate_output(output=None, base_model=None, save_dir=None):
170
- try:
171
- return _save_generate_output(output=output, base_model=base_model, save_dir=save_dir)
172
- except Exception as e:
173
- traceback.print_exc()
174
- print('Exception in saving: %s' % str(e))
175
-
176
-
177
- def _save_generate_output(output=None, base_model=None, save_dir=None):
178
- """
179
- Save conversation to .json, row by row.
180
- save_dir is the directory holding the final history.json; missing directories are created.
181
- Appends if file exists
182
- """
183
- assert save_dir, "save_dir must be provided"
184
- if os.path.exists(save_dir) and not os.path.isdir(save_dir):
185
- raise RuntimeError("save_dir already exists and is not a directory!")
186
- os.makedirs(save_dir, exist_ok=True)
187
- import json
188
- if output[-10:] == '\n\n<human>:':
189
- # remove trailing <human>:
190
- output = output[:-10]
191
- with filelock.FileLock("save_dir.lock"):
192
- # lock logging in case have concurrency
193
- with open(os.path.join(save_dir, "history.json"), "a") as f:
194
- # just add [ at start, and ] at end, and have proper JSON dataset
195
- f.write(
196
- " " + json.dumps(
197
- dict(text=output, time=time.ctime(), base_model=base_model)
198
- ) + ",\n"
199
- )
200
-
201
-
202
- def s3up(filename):
203
- try:
204
- return _s3up(filename)
205
- except Exception as e:
206
- traceback.print_exc()
207
- print('Exception for file %s in s3up: %s' % (filename, str(e)))
208
- return "Failed to upload %s: Error: %s" % (filename, str(e))
209
-
210
-
211
- def _s3up(filename):
212
- import boto3
213
-
214
- aws_access_key_id = os.getenv('AWS_SERVER_PUBLIC_KEY')
215
- aws_secret_access_key = os.getenv('AWS_SERVER_SECRET_KEY')
216
- bucket = os.getenv('AWS_BUCKET')
217
- assert aws_access_key_id, "Set AWS key"
218
- assert aws_secret_access_key, "Set AWS secret"
219
- assert bucket, "Set AWS Bucket"
220
-
221
- s3 = boto3.client('s3',
222
- aws_access_key_id=os.getenv('AWS_SERVER_PUBLIC_KEY'),
223
- aws_secret_access_key=os.getenv('AWS_SERVER_SECRET_KEY'),
224
- )
225
- ret = s3.upload_file(
226
- Filename=filename,
227
- Bucket=os.getenv('AWS_BUCKET'),
228
- Key=filename,
229
- )
230
- if ret in [None, '']:
231
- return "Successfully uploaded %s" % filename
232
-
233
-
234
- def get_githash():
235
- try:
236
- githash = subprocess.run(['git', 'rev-parse', 'HEAD'], stdout=subprocess.PIPE).stdout.decode('utf-8')[0:-1]
237
- except:
238
- githash = ''
239
- return githash
240
-
241
-
242
- def copy_code(run_id):
243
- """
244
- copy code to track changes
245
- :param run_id:
246
- :return:
247
- """
248
- rnd_num = str(random.randint(0, 2 ** 31))
249
- run_id = 'run_' + str(run_id)
250
- os.makedirs(run_id, exist_ok=True)
251
- me_full = os.path.join(pathlib.Path(__file__).parent.resolve(), __file__)
252
- me_file = os.path.basename(__file__)
253
- new_me = os.path.join(run_id, me_file + '_' + get_githash())
254
- if os.path.isfile(new_me):
255
- new_me = os.path.join(run_id, me_file + '_' + get_githash() + '_' + rnd_num)
256
- shutil.copy(me_full, new_me)
257
- else:
258
- shutil.copy(me_full, new_me)
259
-
260
-
261
- class NullContext(threading.local):
262
- """No-op context manager, executes block without doing any additional processing.
263
-
264
- Used as a stand-in if a particular block of code is only sometimes
265
- used with a normal context manager:
266
- """
267
-
268
- def __init__(self, *args, **kwargs):
269
- pass
270
-
271
- def __enter__(self):
272
- return self
273
-
274
- def __exit__(self, exc_type, exc_value, exc_traceback):
275
- self.finally_act()
276
-
277
- def finally_act(self):
278
- pass
279
-
280
-
281
- def wrapped_partial(func, *args, **kwargs):
282
- """
283
- Give partial properties of normal function, like __name__ attribute etc.
284
- :param func:
285
- :param args:
286
- :param kwargs:
287
- :return:
288
- """
289
- partial_func = functools.partial(func, *args, **kwargs)
290
- functools.update_wrapper(partial_func, func)
291
- return partial_func
292
-
293
-
294
- class ThreadException(Exception):
295
- pass
296
-
297
-
298
- class EThread(threading.Thread):
299
- # Function that raises the custom exception
300
- def __init__(self, group=None, target=None, name=None,
301
- args=(), kwargs=None, *, daemon=None, streamer=None, bucket=None):
302
- self.bucket = bucket
303
- self.streamer = streamer
304
- self.exc = None
305
- self._return = None
306
- super().__init__(group=group, target=target, name=name, args=args, kwargs=kwargs, daemon=daemon)
307
-
308
- def run(self):
309
- # Variable that stores the exception, if raised by someFunction
310
- try:
311
- if self._target is not None:
312
- self._return = self._target(*self._args, **self._kwargs)
313
- except BaseException as e:
314
- print("thread exception: %s" % str(sys.exc_info()))
315
- self.bucket.put(sys.exc_info())
316
- self.exc = e
317
- if self.streamer:
318
- print("make stop: %s" % str(sys.exc_info()), flush=True)
319
- self.streamer.do_stop = True
320
- finally:
321
- # Avoid a refcycle if the thread is running a function with
322
- # an argument that has a member that points to the thread.
323
- del self._target, self._args, self._kwargs
324
-
325
- def join(self, timeout=None):
326
- threading.Thread.join(self, timeout)
327
- # Since join() returns in caller thread
328
- # we re-raise the caught exception
329
- # if any was caught
330
- if self.exc:
331
- raise self.exc
332
- return self._return
333
-
334
-
335
- def import_matplotlib():
336
- import matplotlib
337
- matplotlib.use('agg')
338
- # KEEP THESE HERE! START
339
- import matplotlib.pyplot as plt
340
- import pandas as pd
341
- # to avoid dlopen deadlock in fork
342
- import pandas.core.computation.expressions as pd_expressions
343
- import pandas._libs.groupby as pd_libgroupby
344
- import pandas._libs.reduction as pd_libreduction
345
- import pandas.core.algorithms as pd_algorithms
346
- import pandas.core.common as pd_com
347
- import numpy as np
348
- # KEEP THESE HERE! END
349
-
350
-
351
- def get_sha(value):
352
- return hashlib.md5(str(value).encode('utf-8')).hexdigest()
353
-
354
-
355
- def sanitize_filename(name):
356
- """
357
- Sanitize file *base* names.
358
- :param name: name to sanitize
359
- :return:
360
- """
361
- bad_chars = ['[', ']', ',', '/', '\\', '\\w', '\\s', '-', '+', '\"', '\'', '>', '<', ' ', '=', ')', '(', ':', '^']
362
- for char in bad_chars:
363
- name = name.replace(char, "_")
364
-
365
- length = len(name)
366
- file_length_limit = 250 # bit smaller than 256 for safety
367
- sha_length = 32
368
- real_length_limit = file_length_limit - (sha_length + 2)
369
- if length > file_length_limit:
370
- sha = get_sha(name)
371
- half_real_length_limit = max(1, int(real_length_limit / 2))
372
- name = name[0:half_real_length_limit] + "_" + sha + "_" + name[length - half_real_length_limit:length]
373
-
374
- return name
375
-
376
-
377
- def shutil_rmtree(*args, **kwargs):
378
- return shutil.rmtree(*args, **kwargs)
379
-
380
-
381
- def remove(path: str):
382
- try:
383
- if path is not None and os.path.exists(path):
384
- if os.path.isdir(path):
385
- shutil_rmtree(path, ignore_errors=True)
386
- else:
387
- with contextlib.suppress(FileNotFoundError):
388
- os.remove(path)
389
- except:
390
- pass
391
-
392
-
393
- def makedirs(path, exist_ok=True):
394
- """
395
- Avoid some inefficiency in os.makedirs()
396
- :param path:
397
- :param exist_ok:
398
- :return:
399
- """
400
- if os.path.isdir(path) and os.path.exists(path):
401
- assert exist_ok, "Path already exists"
402
- return path
403
- os.makedirs(path, exist_ok=exist_ok)
404
-
405
-
406
- def atomic_move_simple(src, dst):
407
- try:
408
- shutil.move(src, dst)
409
- except (shutil.Error, FileExistsError):
410
- pass
411
- remove(src)
412
-
413
-
414
- def download_simple(url, dest=None, print_func=None):
415
- if print_func is not None:
416
- print_func("BEGIN get url %s" % str(url))
417
- if url.startswith("file://"):
418
- from requests_file import FileAdapter
419
- s = requests.Session()
420
- s.mount('file://', FileAdapter())
421
- url_data = s.get(url, stream=True)
422
- else:
423
- url_data = requests.get(url, stream=True)
424
- if dest is None:
425
- dest = os.path.basename(url)
426
- if url_data.status_code != requests.codes.ok:
427
- msg = "Cannot get url %s, code: %s, reason: %s" % (
428
- str(url),
429
- str(url_data.status_code),
430
- str(url_data.reason),
431
- )
432
- raise requests.exceptions.RequestException(msg)
433
- url_data.raw.decode_content = True
434
- makedirs(os.path.dirname(dest), exist_ok=True)
435
- uuid_tmp = str(uuid.uuid4())[:6]
436
- dest_tmp = dest + "_dl_" + uuid_tmp + ".tmp"
437
- with open(dest_tmp, "wb") as f:
438
- shutil.copyfileobj(url_data.raw, f)
439
- atomic_move_simple(dest_tmp, dest)
440
- if print_func is not None:
441
- print_func("END get url %s" % str(url))
442
-
443
-
444
- def download(url, dest=None, dest_path=None):
445
- if dest_path is not None:
446
- dest = os.path.join(dest_path, os.path.basename(url))
447
- if os.path.isfile(dest):
448
- print("already downloaded %s -> %s" % (url, dest))
449
- return dest
450
- elif dest is not None:
451
- if os.path.exists(dest):
452
- print("already downloaded %s -> %s" % (url, dest))
453
- return dest
454
- else:
455
- uuid_tmp = "dl2_" + str(uuid.uuid4())[:6]
456
- dest = uuid_tmp + os.path.basename(url)
457
-
458
- print("downloading %s to %s" % (url, dest))
459
-
460
- if url.startswith("file://"):
461
- from requests_file import FileAdapter
462
- s = requests.Session()
463
- s.mount('file://', FileAdapter())
464
- url_data = s.get(url, stream=True)
465
- else:
466
- url_data = requests.get(url, stream=True)
467
-
468
- if url_data.status_code != requests.codes.ok:
469
- msg = "Cannot get url %s, code: %s, reason: %s" % (
470
- str(url), str(url_data.status_code), str(url_data.reason))
471
- raise requests.exceptions.RequestException(msg)
472
- url_data.raw.decode_content = True
473
- dirname = os.path.dirname(dest)
474
- if dirname != "" and not os.path.isdir(dirname):
475
- makedirs(os.path.dirname(dest), exist_ok=True)
476
- uuid_tmp = "dl3_" + str(uuid.uuid4())[:6]
477
- dest_tmp = dest + "_" + uuid_tmp + ".tmp"
478
- with open(dest_tmp, 'wb') as f:
479
- shutil.copyfileobj(url_data.raw, f)
480
- try:
481
- shutil.move(dest_tmp, dest)
482
- except FileExistsError:
483
- pass
484
- remove(dest_tmp)
485
- return dest
486
-
487
-
488
- def get_url(x, from_str=False, short_name=False):
489
- if not from_str:
490
- source = x.metadata['source']
491
- else:
492
- source = x
493
- if short_name:
494
- source_name = get_short_name(source)
495
- else:
496
- source_name = source
497
- if source.startswith('http://') or source.startswith('https://'):
498
- return """<a href="%s" target="_blank" rel="noopener noreferrer">%s</a>""" % (
499
- source, source_name)
500
- else:
501
- return """<a href="file/%s" target="_blank" rel="noopener noreferrer">%s</a>""" % (
502
- source, source_name)
503
-
504
-
505
- def get_short_name(name, maxl=50):
506
- if name is None:
507
- return ''
508
- length = len(name)
509
- if length > maxl:
510
- allow_length = maxl - 3
511
- half_allowed = max(1, int(allow_length / 2))
512
- name = name[0:half_allowed] + "..." + name[length - half_allowed:length]
513
- return name
514
-
515
-
516
- def cuda_vis_check(total_gpus):
517
- """Helper function to count GPUs by environment variable
518
- Stolen from Jon's h2o4gpu utils
519
- """
520
- cudavis = os.getenv("CUDA_VISIBLE_DEVICES")
521
- which_gpus = []
522
- if cudavis is not None:
523
- # prune away white-space, non-numerics,
524
- # except commas for simple checking
525
- cudavis = "".join(cudavis.split())
526
- import re
527
- cudavis = re.sub("[^0-9,]", "", cudavis)
528
-
529
- lencudavis = len(cudavis)
530
- if lencudavis == 0:
531
- total_gpus = 0
532
- else:
533
- total_gpus = min(
534
- total_gpus,
535
- os.getenv("CUDA_VISIBLE_DEVICES").count(",") + 1)
536
- which_gpus = os.getenv("CUDA_VISIBLE_DEVICES").split(",")
537
- which_gpus = [int(x) for x in which_gpus]
538
- else:
539
- which_gpus = list(range(0, total_gpus))
540
-
541
- return total_gpus, which_gpus
542
-
543
-
544
- def get_ngpus_vis(raise_if_exception=True):
545
- ngpus_vis1 = 0
546
-
547
- shell = False
548
- if shell:
549
- cmd = "nvidia-smi -L 2> /dev/null"
550
- else:
551
- cmd = ["nvidia-smi", "-L"]
552
-
553
- try:
554
- timeout = 5 * 3
555
- o = subprocess.check_output(cmd, shell=shell, timeout=timeout)
556
- lines = o.decode("utf-8").splitlines()
557
- ngpus_vis1 = 0
558
- for line in lines:
559
- if 'Failed to initialize NVML' not in line:
560
- ngpus_vis1 += 1
561
- except (FileNotFoundError, subprocess.CalledProcessError, OSError):
562
- # GPU systems might not have nvidia-smi, so can't fail
563
- pass
564
- except subprocess.TimeoutExpired as e:
565
- print('Failed get_ngpus_vis: %s' % str(e))
566
- if raise_if_exception:
567
- raise
568
-
569
- ngpus_vis1, which_gpus = cuda_vis_check(ngpus_vis1)
570
- return ngpus_vis1
571
-
572
-
573
- def get_mem_gpus(raise_if_exception=True, ngpus=None):
574
- totalmem_gpus1 = 0
575
- usedmem_gpus1 = 0
576
- freemem_gpus1 = 0
577
-
578
- if ngpus == 0:
579
- return totalmem_gpus1, usedmem_gpus1, freemem_gpus1
580
-
581
- try:
582
- cmd = "nvidia-smi -q 2> /dev/null | grep -A 3 'FB Memory Usage'"
583
- o = subprocess.check_output(cmd, shell=True, timeout=15)
584
- lines = o.decode("utf-8").splitlines()
585
- for line in lines:
586
- if 'Total' in line:
587
- totalmem_gpus1 += int(line.split()[2]) * 1024 ** 2
588
- if 'Used' in line:
589
- usedmem_gpus1 += int(line.split()[2]) * 1024 ** 2
590
- if 'Free' in line:
591
- freemem_gpus1 += int(line.split()[2]) * 1024 ** 2
592
- except (FileNotFoundError, subprocess.CalledProcessError, OSError):
593
- # GPU systems might not have nvidia-smi, so can't fail
594
- pass
595
- except subprocess.TimeoutExpired as e:
596
- print('Failed get_mem_gpus: %s' % str(e))
597
- if raise_if_exception:
598
- raise
599
-
600
- return totalmem_gpus1, usedmem_gpus1, freemem_gpus1
601
-
602
-
603
- class ForkContext(threading.local):
604
- """
605
- Set context for forking
606
- Ensures state is returned once done
607
- """
608
-
609
- def __init__(self, args=None, kwargs=None, forkdata_capable=True):
610
- """
611
- :param args:
612
- :param kwargs:
613
- :param forkdata_capable: whether fork is forkdata capable and will use copy-on-write forking of args/kwargs
614
- """
615
- self.forkdata_capable = forkdata_capable
616
- if self.forkdata_capable:
617
- self.has_args = args is not None
618
- self.has_kwargs = kwargs is not None
619
- forkdatacontext.args = args
620
- forkdatacontext.kwargs = kwargs
621
- else:
622
- self.has_args = False
623
- self.has_kwargs = False
624
-
625
- def __enter__(self):
626
- try:
627
- # flush all outputs so doesn't happen during fork -- don't print/log inside ForkContext contexts!
628
- sys.stdout.flush()
629
- sys.stderr.flush()
630
- except BaseException as e:
631
- # exit not called if exception, and don't want to leave forkdatacontext filled in that case
632
- print("ForkContext failure on enter: %s" % str(e))
633
- self.finally_act()
634
- raise
635
- return self
636
-
637
- def __exit__(self, exc_type, exc_value, exc_traceback):
638
- self.finally_act()
639
-
640
- def finally_act(self):
641
- """
642
- Done when exception hit or exit is reached in context
643
- first reset forkdatacontext as crucial to have reset even if later 2 calls fail
644
- :return: None
645
- """
646
- if self.forkdata_capable and (self.has_args or self.has_kwargs):
647
- forkdatacontext._reset()
648
-
649
-
650
- class _ForkDataContext(threading.local):
651
- def __init__(
652
- self,
653
- args=None,
654
- kwargs=None,
655
- ):
656
- """
657
- Global context for fork to carry data to subprocess instead of relying upon copy/pickle/serialization
658
-
659
- :param args: args
660
- :param kwargs: kwargs
661
- """
662
- assert isinstance(args, (tuple, type(None)))
663
- assert isinstance(kwargs, (dict, type(None)))
664
- self.__args = args
665
- self.__kwargs = kwargs
666
-
667
- @property
668
- def args(self) -> Tuple:
669
- """returns args"""
670
- return self.__args
671
-
672
- @args.setter
673
- def args(self, args):
674
- if self.__args is not None:
675
- raise AttributeError(
676
- "args cannot be overwritten: %s %s" % (str(self.__args), str(self.__kwargs))
677
- )
678
-
679
- self.__args = args
680
-
681
- @property
682
- def kwargs(self) -> Dict:
683
- """returns kwargs"""
684
- return self.__kwargs
685
-
686
- @kwargs.setter
687
- def kwargs(self, kwargs):
688
- if self.__kwargs is not None:
689
- raise AttributeError(
690
- "kwargs cannot be overwritten: %s %s" % (str(self.__args), str(self.__kwargs))
691
- )
692
-
693
- self.__kwargs = kwargs
694
-
695
- def _reset(self):
696
- """Reset fork arg-kwarg context to default values"""
697
- self.__args = None
698
- self.__kwargs = None
699
-
700
- def get_args_kwargs(self, func, args, kwargs) -> Tuple[Callable, Tuple, Dict]:
701
- if self.__args:
702
- args = self.__args[1:]
703
- if not func:
704
- assert len(self.__args) > 0, "if have no func, must have in args"
705
- func = self.__args[0] # should always be there
706
- if self.__kwargs:
707
- kwargs = self.__kwargs
708
- try:
709
- return func, args, kwargs
710
- finally:
711
- forkdatacontext._reset()
712
-
713
- @staticmethod
714
- def get_args_kwargs_for_traced_func(func, args, kwargs):
715
- """
716
- Return args/kwargs out of forkdatacontext when using copy-on-write way of passing args/kwargs
717
- :param func: actual function ran by _traced_func, which itself is directly what mppool treats as function
718
- :param args:
719
- :param kwargs:
720
- :return: func, args, kwargs from forkdatacontext if used, else originals
721
- """
722
- # first 3 lines are debug
723
- func_was_None = func is None
724
- args_was_None_or_empty = args is None or len(args) == 0
725
- kwargs_was_None_or_empty = kwargs is None or len(kwargs) == 0
726
-
727
- forkdatacontext_args_was_None = forkdatacontext.args is None
728
- forkdatacontext_kwargs_was_None = forkdatacontext.kwargs is None
729
- func, args, kwargs = forkdatacontext.get_args_kwargs(func, args, kwargs)
730
- using_forkdatacontext = func_was_None and func is not None # pulled func out of forkdatacontext.__args[0]
731
- assert forkdatacontext.args is None, "forkdatacontext.args should be None after get_args_kwargs"
732
- assert forkdatacontext.kwargs is None, "forkdatacontext.kwargs should be None after get_args_kwargs"
733
-
734
- proc_type = kwargs.get('proc_type', 'SUBPROCESS')
735
- if using_forkdatacontext:
736
- assert proc_type == "SUBPROCESS"
737
- if proc_type == "NORMAL":
738
- assert forkdatacontext_args_was_None, "if no fork, expect forkdatacontext.args None entering _traced_func"
739
- assert forkdatacontext_kwargs_was_None, "if no fork, expect forkdatacontext.kwargs None entering _traced_func"
740
- assert func is not None, "function should not be None, indicates original args[0] was None or args was None"
741
-
742
- return func, args, kwargs
743
-
744
-
745
- forkdatacontext = _ForkDataContext()
746
-
747
-
748
- def _traced_func(func, *args, **kwargs):
749
- func, args, kwargs = forkdatacontext.get_args_kwargs_for_traced_func(func, args, kwargs)
750
- return func(*args, **kwargs)
751
-
752
-
753
- def call_subprocess_onetask(func, args=None, kwargs=None):
754
- if isinstance(args, list):
755
- args = tuple(args)
756
- if args is None:
757
- args = ()
758
- if kwargs is None:
759
- kwargs = {}
760
- args = list(args)
761
- args = [func] + args
762
- args = tuple(args)
763
- with ForkContext(args=args, kwargs=kwargs):
764
- args = (None,)
765
- kwargs = {}
766
- with ProcessPoolExecutor(max_workers=1) as executor:
767
- future = executor.submit(_traced_func, *args, **kwargs)
768
- return future.result()
769
-
770
-
771
- class ProgressParallel(Parallel):
772
- def __init__(self, use_tqdm=True, total=None, *args, **kwargs):
773
- self._use_tqdm = use_tqdm
774
- self._total = total
775
- super().__init__(*args, **kwargs)
776
-
777
- def __call__(self, *args, **kwargs):
778
- with tqdm(disable=not self._use_tqdm, total=self._total) as self._pbar:
779
- return Parallel.__call__(self, *args, **kwargs)
780
-
781
- def print_progress(self):
782
- if self._total is None:
783
- self._pbar.total = self.n_dispatched_tasks
784
- self._pbar.n = self.n_completed_tasks
785
- self._pbar.refresh()
786
-
787
-
788
- def get_kwargs(func, exclude_names=None, **kwargs):
789
- func_names = list(inspect.signature(func).parameters)
790
- missing_kwargs = [x for x in func_names if x not in kwargs]
791
- if exclude_names:
792
- for k in exclude_names:
793
- if k in missing_kwargs:
794
- missing_kwargs.remove(k)
795
- if k in func_names:
796
- func_names.remove(k)
797
- assert not missing_kwargs, "Missing %s" % missing_kwargs
798
- kwargs = {k: v for k, v in kwargs.items() if k in func_names}
799
- return kwargs
800
-
801
-
802
- import pkg_resources
803
- have_faiss = False
804
-
805
- try:
806
- assert pkg_resources.get_distribution('faiss') is not None
807
- have_faiss = True
808
- except (pkg_resources.DistributionNotFound, AssertionError):
809
- pass
810
- try:
811
- assert pkg_resources.get_distribution('faiss_gpu') is not None
812
- have_faiss = True
813
- except (pkg_resources.DistributionNotFound, AssertionError):
814
- pass
815
- try:
816
- assert pkg_resources.get_distribution('faiss_cpu') is not None
817
- have_faiss = True
818
- except (pkg_resources.DistributionNotFound, AssertionError):
819
- pass
820
-
821
-
822
- def hash_file(file):
823
- try:
824
- import hashlib
825
-
826
- # BUF_SIZE is totally arbitrary, change for your app!
827
- BUF_SIZE = 65536 # lets read stuff in 64kb chunks!
828
-
829
- md5 = hashlib.md5()
830
- #sha1 = hashlib.sha1()
831
-
832
- with open(file, 'rb') as f:
833
- while True:
834
- data = f.read(BUF_SIZE)
835
- if not data:
836
- break
837
- md5.update(data)
838
- #sha1.update(data)
839
- except BaseException as e:
840
- print("Cannot hash %s due to %s" % (file, str(e)))
841
- traceback.print_exc()
842
- md5 = None
843
- return md5.hexdigest() if md5 is not None else ''
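
A short usage sketch for a few of the helpers above (not part of the commit); it assumes the module is importable as utils from the repository root.

from utils import NullContext, get_short_name, sanitize_filename, wrapped_partial

print(sanitize_filename("my report (final), v2.pdf"))  # unsafe characters become "_"
print(get_short_name("a" * 80, maxl=20))               # long names are elided in the middle with "..."

def greet(greeting, name):
    return "%s, %s!" % (greeting, name)

hello = wrapped_partial(greet, "Hello")  # partial that keeps greet's __name__/__doc__
print(hello.__name__, hello("world"))

with NullContext():  # stand-in when a real context manager is only sometimes needed
    print("ran with no extra setup or teardown")
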
utils.py ADDED
@@ -0,0 +1 @@
1
+ ../../utils.py