Upload folder using huggingface_hub
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- .gitattributes +0 -1
- LICENSE +201 -0
- README.md +3 -3
- app.py +1 -0
- generate.py +16 -0
- gradio_utils/__init__.py +0 -0
- gradio_utils/__pycache__/__init__.cpython-310.pyc +0 -0
- gradio_utils/__pycache__/css.cpython-310.pyc +0 -0
- gradio_utils/__pycache__/grclient.cpython-310.pyc +0 -0
- gradio_utils/__pycache__/prompt_form.cpython-310.pyc +0 -0
- gradio_utils/css.py +148 -0
- gradio_utils/grclient.py +82 -0
- gradio_utils/prompt_form.py +108 -0
- h2o-logo.svg +1 -0
- iterators/__init__.py +4 -0
- iterators/__pycache__/__init__.cpython-310.pyc +0 -0
- iterators/__pycache__/iterator_pipe.cpython-310.pyc +0 -0
- iterators/__pycache__/timeout_iterator.cpython-310.pyc +0 -0
- iterators/iterator_pipe.py +93 -0
- iterators/timeout_iterator.py +170 -0
- requirements.txt +192 -0
- src/LICENSE +201 -0
- src/__pycache__/enums.cpython-310.pyc +0 -0
- src/__pycache__/evaluate_params.cpython-310.pyc +0 -0
- src/__pycache__/gen.cpython-310.pyc +0 -0
- src/__pycache__/gen.cpython-312.pyc +0 -0
- src/__pycache__/gpt_langchain.cpython-310.pyc +0 -0
- src/__pycache__/loaders.cpython-310.pyc +0 -0
- src/__pycache__/prompter.cpython-310.pyc +0 -0
- src/__pycache__/stopping.cpython-310.pyc +0 -0
- src/__pycache__/utils.cpython-310.pyc +0 -0
- src/client_test.py +484 -0
- src/create_data.py +1847 -0
- src/db_utils.py +54 -0
- src/enums.py +225 -0
- src/evaluate_params.py +71 -0
- src/gen.py +0 -0
- src/gpt4all_llm.py +403 -0
- src/gpt_langchain.py +0 -0
- src/gradio_runner.py +0 -0
- src/gradio_themes.py +260 -0
- src/gradio_utils/__init__.py +0 -0
- src/gradio_utils/__pycache__/__init__.cpython-310.pyc +0 -0
- src/gradio_utils/__pycache__/css.cpython-310.pyc +0 -0
- src/gradio_utils/__pycache__/grclient.cpython-310.pyc +0 -0
- src/gradio_utils/__pycache__/prompt_form.cpython-310.pyc +0 -0
- src/gradio_utils/css.py +148 -0
- src/gradio_utils/grclient.py +82 -0
- src/gradio_utils/prompt_form.py +108 -0
- src/h2o-logo.svg +1 -0
.gitattributes
CHANGED
@@ -25,7 +25,6 @@
|
|
25 |
*.safetensors filter=lfs diff=lfs merge=lfs -text
|
26 |
saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
27 |
*.tar.* filter=lfs diff=lfs merge=lfs -text
|
28 |
-
*.tar filter=lfs diff=lfs merge=lfs -text
|
29 |
*.tflite filter=lfs diff=lfs merge=lfs -text
|
30 |
*.tgz filter=lfs diff=lfs merge=lfs -text
|
31 |
*.wasm filter=lfs diff=lfs merge=lfs -text
|
|
|
25 |
*.safetensors filter=lfs diff=lfs merge=lfs -text
|
26 |
saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
27 |
*.tar.* filter=lfs diff=lfs merge=lfs -text
|
|
|
28 |
*.tflite filter=lfs diff=lfs merge=lfs -text
|
29 |
*.tgz filter=lfs diff=lfs merge=lfs -text
|
30 |
*.wasm filter=lfs diff=lfs merge=lfs -text
|
LICENSE
ADDED
@@ -0,0 +1,201 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
Apache License
|
2 |
+
Version 2.0, January 2004
|
3 |
+
http://www.apache.org/licenses/
|
4 |
+
|
5 |
+
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
|
6 |
+
|
7 |
+
1. Definitions.
|
8 |
+
|
9 |
+
"License" shall mean the terms and conditions for use, reproduction,
|
10 |
+
and distribution as defined by Sections 1 through 9 of this document.
|
11 |
+
|
12 |
+
"Licensor" shall mean the copyright owner or entity authorized by
|
13 |
+
the copyright owner that is granting the License.
|
14 |
+
|
15 |
+
"Legal Entity" shall mean the union of the acting entity and all
|
16 |
+
other entities that control, are controlled by, or are under common
|
17 |
+
control with that entity. For the purposes of this definition,
|
18 |
+
"control" means (i) the power, direct or indirect, to cause the
|
19 |
+
direction or management of such entity, whether by contract or
|
20 |
+
otherwise, or (ii) ownership of fifty percent (50%) or more of the
|
21 |
+
outstanding shares, or (iii) beneficial ownership of such entity.
|
22 |
+
|
23 |
+
"You" (or "Your") shall mean an individual or Legal Entity
|
24 |
+
exercising permissions granted by this License.
|
25 |
+
|
26 |
+
"Source" form shall mean the preferred form for making modifications,
|
27 |
+
including but not limited to software source code, documentation
|
28 |
+
source, and configuration files.
|
29 |
+
|
30 |
+
"Object" form shall mean any form resulting from mechanical
|
31 |
+
transformation or translation of a Source form, including but
|
32 |
+
not limited to compiled object code, generated documentation,
|
33 |
+
and conversions to other media types.
|
34 |
+
|
35 |
+
"Work" shall mean the work of authorship, whether in Source or
|
36 |
+
Object form, made available under the License, as indicated by a
|
37 |
+
copyright notice that is included in or attached to the work
|
38 |
+
(an example is provided in the Appendix below).
|
39 |
+
|
40 |
+
"Derivative Works" shall mean any work, whether in Source or Object
|
41 |
+
form, that is based on (or derived from) the Work and for which the
|
42 |
+
editorial revisions, annotations, elaborations, or other modifications
|
43 |
+
represent, as a whole, an original work of authorship. For the purposes
|
44 |
+
of this License, Derivative Works shall not include works that remain
|
45 |
+
separable from, or merely link (or bind by name) to the interfaces of,
|
46 |
+
the Work and Derivative Works thereof.
|
47 |
+
|
48 |
+
"Contribution" shall mean any work of authorship, including
|
49 |
+
the original version of the Work and any modifications or additions
|
50 |
+
to that Work or Derivative Works thereof, that is intentionally
|
51 |
+
submitted to Licensor for inclusion in the Work by the copyright owner
|
52 |
+
or by an individual or Legal Entity authorized to submit on behalf of
|
53 |
+
the copyright owner. For the purposes of this definition, "submitted"
|
54 |
+
means any form of electronic, verbal, or written communication sent
|
55 |
+
to the Licensor or its representatives, including but not limited to
|
56 |
+
communication on electronic mailing lists, source code control systems,
|
57 |
+
and issue tracking systems that are managed by, or on behalf of, the
|
58 |
+
Licensor for the purpose of discussing and improving the Work, but
|
59 |
+
excluding communication that is conspicuously marked or otherwise
|
60 |
+
designated in writing by the copyright owner as "Not a Contribution."
|
61 |
+
|
62 |
+
"Contributor" shall mean Licensor and any individual or Legal Entity
|
63 |
+
on behalf of whom a Contribution has been received by Licensor and
|
64 |
+
subsequently incorporated within the Work.
|
65 |
+
|
66 |
+
2. Grant of Copyright License. Subject to the terms and conditions of
|
67 |
+
this License, each Contributor hereby grants to You a perpetual,
|
68 |
+
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
69 |
+
copyright license to reproduce, prepare Derivative Works of,
|
70 |
+
publicly display, publicly perform, sublicense, and distribute the
|
71 |
+
Work and such Derivative Works in Source or Object form.
|
72 |
+
|
73 |
+
3. Grant of Patent License. Subject to the terms and conditions of
|
74 |
+
this License, each Contributor hereby grants to You a perpetual,
|
75 |
+
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
76 |
+
(except as stated in this section) patent license to make, have made,
|
77 |
+
use, offer to sell, sell, import, and otherwise transfer the Work,
|
78 |
+
where such license applies only to those patent claims licensable
|
79 |
+
by such Contributor that are necessarily infringed by their
|
80 |
+
Contribution(s) alone or by combination of their Contribution(s)
|
81 |
+
with the Work to which such Contribution(s) was submitted. If You
|
82 |
+
institute patent litigation against any entity (including a
|
83 |
+
cross-claim or counterclaim in a lawsuit) alleging that the Work
|
84 |
+
or a Contribution incorporated within the Work constitutes direct
|
85 |
+
or contributory patent infringement, then any patent licenses
|
86 |
+
granted to You under this License for that Work shall terminate
|
87 |
+
as of the date such litigation is filed.
|
88 |
+
|
89 |
+
4. Redistribution. You may reproduce and distribute copies of the
|
90 |
+
Work or Derivative Works thereof in any medium, with or without
|
91 |
+
modifications, and in Source or Object form, provided that You
|
92 |
+
meet the following conditions:
|
93 |
+
|
94 |
+
(a) You must give any other recipients of the Work or
|
95 |
+
Derivative Works a copy of this License; and
|
96 |
+
|
97 |
+
(b) You must cause any modified files to carry prominent notices
|
98 |
+
stating that You changed the files; and
|
99 |
+
|
100 |
+
(c) You must retain, in the Source form of any Derivative Works
|
101 |
+
that You distribute, all copyright, patent, trademark, and
|
102 |
+
attribution notices from the Source form of the Work,
|
103 |
+
excluding those notices that do not pertain to any part of
|
104 |
+
the Derivative Works; and
|
105 |
+
|
106 |
+
(d) If the Work includes a "NOTICE" text file as part of its
|
107 |
+
distribution, then any Derivative Works that You distribute must
|
108 |
+
include a readable copy of the attribution notices contained
|
109 |
+
within such NOTICE file, excluding those notices that do not
|
110 |
+
pertain to any part of the Derivative Works, in at least one
|
111 |
+
of the following places: within a NOTICE text file distributed
|
112 |
+
as part of the Derivative Works; within the Source form or
|
113 |
+
documentation, if provided along with the Derivative Works; or,
|
114 |
+
within a display generated by the Derivative Works, if and
|
115 |
+
wherever such third-party notices normally appear. The contents
|
116 |
+
of the NOTICE file are for informational purposes only and
|
117 |
+
do not modify the License. You may add Your own attribution
|
118 |
+
notices within Derivative Works that You distribute, alongside
|
119 |
+
or as an addendum to the NOTICE text from the Work, provided
|
120 |
+
that such additional attribution notices cannot be construed
|
121 |
+
as modifying the License.
|
122 |
+
|
123 |
+
You may add Your own copyright statement to Your modifications and
|
124 |
+
may provide additional or different license terms and conditions
|
125 |
+
for use, reproduction, or distribution of Your modifications, or
|
126 |
+
for any such Derivative Works as a whole, provided Your use,
|
127 |
+
reproduction, and distribution of the Work otherwise complies with
|
128 |
+
the conditions stated in this License.
|
129 |
+
|
130 |
+
5. Submission of Contributions. Unless You explicitly state otherwise,
|
131 |
+
any Contribution intentionally submitted for inclusion in the Work
|
132 |
+
by You to the Licensor shall be under the terms and conditions of
|
133 |
+
this License, without any additional terms or conditions.
|
134 |
+
Notwithstanding the above, nothing herein shall supersede or modify
|
135 |
+
the terms of any separate license agreement you may have executed
|
136 |
+
with Licensor regarding such Contributions.
|
137 |
+
|
138 |
+
6. Trademarks. This License does not grant permission to use the trade
|
139 |
+
names, trademarks, service marks, or product names of the Licensor,
|
140 |
+
except as required for reasonable and customary use in describing the
|
141 |
+
origin of the Work and reproducing the content of the NOTICE file.
|
142 |
+
|
143 |
+
7. Disclaimer of Warranty. Unless required by applicable law or
|
144 |
+
agreed to in writing, Licensor provides the Work (and each
|
145 |
+
Contributor provides its Contributions) on an "AS IS" BASIS,
|
146 |
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
|
147 |
+
implied, including, without limitation, any warranties or conditions
|
148 |
+
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
|
149 |
+
PARTICULAR PURPOSE. You are solely responsible for determining the
|
150 |
+
appropriateness of using or redistributing the Work and assume any
|
151 |
+
risks associated with Your exercise of permissions under this License.
|
152 |
+
|
153 |
+
8. Limitation of Liability. In no event and under no legal theory,
|
154 |
+
whether in tort (including negligence), contract, or otherwise,
|
155 |
+
unless required by applicable law (such as deliberate and grossly
|
156 |
+
negligent acts) or agreed to in writing, shall any Contributor be
|
157 |
+
liable to You for damages, including any direct, indirect, special,
|
158 |
+
incidental, or consequential damages of any character arising as a
|
159 |
+
result of this License or out of the use or inability to use the
|
160 |
+
Work (including but not limited to damages for loss of goodwill,
|
161 |
+
work stoppage, computer failure or malfunction, or any and all
|
162 |
+
other commercial damages or losses), even if such Contributor
|
163 |
+
has been advised of the possibility of such damages.
|
164 |
+
|
165 |
+
9. Accepting Warranty or Additional Liability. While redistributing
|
166 |
+
the Work or Derivative Works thereof, You may choose to offer,
|
167 |
+
and charge a fee for, acceptance of support, warranty, indemnity,
|
168 |
+
or other liability obligations and/or rights consistent with this
|
169 |
+
License. However, in accepting such obligations, You may act only
|
170 |
+
on Your own behalf and on Your sole responsibility, not on behalf
|
171 |
+
of any other Contributor, and only if You agree to indemnify,
|
172 |
+
defend, and hold each Contributor harmless for any liability
|
173 |
+
incurred by, or claims asserted against, such Contributor by reason
|
174 |
+
of your accepting any such warranty or additional liability.
|
175 |
+
|
176 |
+
END OF TERMS AND CONDITIONS
|
177 |
+
|
178 |
+
APPENDIX: How to apply the Apache License to your work.
|
179 |
+
|
180 |
+
To apply the Apache License to your work, attach the following
|
181 |
+
boilerplate notice, with the fields enclosed by brackets "[]"
|
182 |
+
replaced with your own identifying information. (Don't include
|
183 |
+
the brackets!) The text should be enclosed in the appropriate
|
184 |
+
comment syntax for the file format. We also recommend that a
|
185 |
+
file or class name and description of purpose be included on the
|
186 |
+
same "printed page" as the copyright notice for easier
|
187 |
+
identification within third-party archives.
|
188 |
+
|
189 |
+
Copyright 2023 Rohan Taori, Ishaan Gulrajani, Tianyi Zhang, Yann Dubois, Xuechen Li
|
190 |
+
|
191 |
+
Licensed under the Apache License, Version 2.0 (the "License");
|
192 |
+
you may not use this file except in compliance with the License.
|
193 |
+
You may obtain a copy of the License at
|
194 |
+
|
195 |
+
http://www.apache.org/licenses/LICENSE-2.0
|
196 |
+
|
197 |
+
Unless required by applicable law or agreed to in writing, software
|
198 |
+
distributed under the License is distributed on an "AS IS" BASIS,
|
199 |
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
200 |
+
See the License for the specific language governing permissions and
|
201 |
+
limitations under the License.
|
README.md
CHANGED
@@ -1,12 +1,12 @@
|
|
1 |
---
|
2 |
-
title:
|
3 |
-
emoji: 📊
|
4 |
colorFrom: yellow
|
5 |
colorTo: yellow
|
6 |
sdk: gradio
|
7 |
-
sdk_version:
|
8 |
app_file: app.py
|
9 |
pinned: false
|
|
|
10 |
---
|
11 |
|
12 |
Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
|
|
|
1 |
---
|
2 |
+
title: test
|
|
|
3 |
colorFrom: yellow
|
4 |
colorTo: yellow
|
5 |
sdk: gradio
|
6 |
+
sdk_version: 3.41.2
|
7 |
app_file: app.py
|
8 |
pinned: false
|
9 |
+
license: apache-2.0
|
10 |
---
|
11 |
|
12 |
Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
|
app.py
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
generate.py
|
generate.py
ADDED
@@ -0,0 +1,16 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import sys
|
3 |
+
|
4 |
+
if os.path.dirname(os.path.abspath(__file__)) not in sys.path:
|
5 |
+
sys.path.append(os.path.dirname(os.path.abspath(__file__)))
|
6 |
+
|
7 |
+
from src.gen import main
|
8 |
+
from src.utils import H2O_Fire
|
9 |
+
|
10 |
+
|
11 |
+
def entrypoint_main():
|
12 |
+
H2O_Fire(main)
|
13 |
+
|
14 |
+
|
15 |
+
if __name__ == "__main__":
|
16 |
+
entrypoint_main()
|
gradio_utils/__init__.py
ADDED
File without changes
|
gradio_utils/__pycache__/__init__.cpython-310.pyc
ADDED
Binary file (134 Bytes). View file
|
|
gradio_utils/__pycache__/css.cpython-310.pyc
ADDED
Binary file (3.65 kB). View file
|
|
gradio_utils/__pycache__/grclient.cpython-310.pyc
ADDED
Binary file (2.69 kB). View file
|
|
gradio_utils/__pycache__/prompt_form.cpython-310.pyc
ADDED
Binary file (2.96 kB). View file
|
|
gradio_utils/css.py
ADDED
@@ -0,0 +1,148 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
def get_css(kwargs) -> str:
|
2 |
+
if kwargs['h2ocolors']:
|
3 |
+
css_code = """footer {visibility: hidden;}
|
4 |
+
body{background:linear-gradient(#f5f5f5,#e5e5e5);}
|
5 |
+
body.dark{background:linear-gradient(#000000,#0d0d0d);}
|
6 |
+
"""
|
7 |
+
else:
|
8 |
+
css_code = """footer {visibility: hidden}"""
|
9 |
+
|
10 |
+
css_code += make_css_base()
|
11 |
+
return css_code
|
12 |
+
|
13 |
+
|
14 |
+
def make_css_base() -> str:
|
15 |
+
return """
|
16 |
+
#col_container {margin-left: auto; margin-right: auto; text-align: left;}
|
17 |
+
|
18 |
+
@import url('https://fonts.googleapis.com/css2?family=Source+Sans+Pro:wght@400;600&display=swap');
|
19 |
+
|
20 |
+
body.dark{#warning {background-color: #555555};}
|
21 |
+
|
22 |
+
#sidebar {
|
23 |
+
order: 1;
|
24 |
+
|
25 |
+
@media (max-width: 463px) {
|
26 |
+
order: 2;
|
27 |
+
}
|
28 |
+
}
|
29 |
+
|
30 |
+
#col-tabs {
|
31 |
+
order: 2;
|
32 |
+
|
33 |
+
@media (max-width: 463px) {
|
34 |
+
order: 1;
|
35 |
+
}
|
36 |
+
}
|
37 |
+
|
38 |
+
#small_btn {
|
39 |
+
margin: 0.6em 0em 0.55em 0;
|
40 |
+
max-width: 20em;
|
41 |
+
min-width: 5em !important;
|
42 |
+
height: 5em;
|
43 |
+
font-size: 14px !important;
|
44 |
+
}
|
45 |
+
|
46 |
+
#prompt-form {
|
47 |
+
border: 1px solid var(--primary-500) !important;
|
48 |
+
}
|
49 |
+
|
50 |
+
#prompt-form.block {
|
51 |
+
border-radius: var(--block-radius) !important;
|
52 |
+
}
|
53 |
+
|
54 |
+
#prompt-form textarea {
|
55 |
+
border: 1px solid rgb(209, 213, 219);
|
56 |
+
}
|
57 |
+
|
58 |
+
#prompt-form label > div {
|
59 |
+
margin-top: 4px;
|
60 |
+
}
|
61 |
+
|
62 |
+
button.primary:hover {
|
63 |
+
background-color: var(--primary-600) !important;
|
64 |
+
transition: .2s;
|
65 |
+
}
|
66 |
+
|
67 |
+
#prompt-form-area {
|
68 |
+
margin-bottom: 2.5rem;
|
69 |
+
}
|
70 |
+
.chatsmall chatbot {font-size: 10px !important}
|
71 |
+
|
72 |
+
.gradio-container {
|
73 |
+
max-width: none !important;
|
74 |
+
}
|
75 |
+
|
76 |
+
div.message {
|
77 |
+
padding: var(--text-lg) !important;
|
78 |
+
}
|
79 |
+
|
80 |
+
div.message.user > div.icon-button {
|
81 |
+
top: unset;
|
82 |
+
bottom: 0;
|
83 |
+
}
|
84 |
+
|
85 |
+
div.message.bot > div.icon-button {
|
86 |
+
top: unset;
|
87 |
+
bottom: 0;
|
88 |
+
}
|
89 |
+
|
90 |
+
#prompt-form-row {
|
91 |
+
position: relative;
|
92 |
+
}
|
93 |
+
|
94 |
+
#attach-button {
|
95 |
+
position: absolute;
|
96 |
+
top: 45px;
|
97 |
+
right: 20px;
|
98 |
+
|
99 |
+
display: flex;
|
100 |
+
justify-content: center;
|
101 |
+
border: 1px solid var(--primary-500) !important;
|
102 |
+
|
103 |
+
@media (max-width: 463px) {
|
104 |
+
width: 56px;
|
105 |
+
}
|
106 |
+
}
|
107 |
+
|
108 |
+
#attach-button > img {
|
109 |
+
margin-right: 0;
|
110 |
+
}
|
111 |
+
|
112 |
+
#prompt-form > label > textarea {
|
113 |
+
padding-right: 104px;
|
114 |
+
|
115 |
+
@media (max-width: 463px) {
|
116 |
+
min-height: 94px;
|
117 |
+
padding-right: 70px;
|
118 |
+
}
|
119 |
+
}
|
120 |
+
|
121 |
+
#visible-models > label > div.wrap > div.wrap-inner > div.secondary-wrap > div.remove-all {
|
122 |
+
display: none !important;
|
123 |
+
}
|
124 |
+
|
125 |
+
#visible-models > label > div.wrap > div.wrap-inner > div.token {
|
126 |
+
display: none !important;
|
127 |
+
}
|
128 |
+
|
129 |
+
#visible-models > label > div.wrap > div.wrap-inner > div.secondary-wrap::before {
|
130 |
+
content: "Select";
|
131 |
+
padding: 0 4px;
|
132 |
+
margin-right: 2px;
|
133 |
+
}
|
134 |
+
|
135 |
+
#langchain_agents > label > div.wrap > div.wrap-inner > div.secondary-wrap > div.remove-all {
|
136 |
+
display: none !important;
|
137 |
+
}
|
138 |
+
|
139 |
+
#langchain_agents > label > div.wrap > div.wrap-inner > div.token {
|
140 |
+
display: none !important;
|
141 |
+
}
|
142 |
+
|
143 |
+
#langchain_agents > label > div.wrap > div.wrap-inner > div.secondary-wrap::before {
|
144 |
+
content: "Select";
|
145 |
+
padding: 0 4px;
|
146 |
+
margin-right: 2px;
|
147 |
+
}
|
148 |
+
"""
|
gradio_utils/grclient.py
ADDED
@@ -0,0 +1,82 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import traceback
|
2 |
+
from typing import Callable
|
3 |
+
import os
|
4 |
+
|
5 |
+
from gradio_client.client import Job
|
6 |
+
|
7 |
+
os.environ['HF_HUB_DISABLE_TELEMETRY'] = '1'
|
8 |
+
|
9 |
+
from gradio_client import Client
|
10 |
+
|
11 |
+
|
12 |
+
class GradioClient(Client):
|
13 |
+
"""
|
14 |
+
Parent class of gradio client
|
15 |
+
To handle automatically refreshing client if detect gradio server changed
|
16 |
+
"""
|
17 |
+
|
18 |
+
def __init__(self, *args, **kwargs):
|
19 |
+
self.args = args
|
20 |
+
self.kwargs = kwargs
|
21 |
+
super().__init__(*args, **kwargs)
|
22 |
+
self.server_hash = self.get_server_hash()
|
23 |
+
|
24 |
+
def get_server_hash(self):
|
25 |
+
"""
|
26 |
+
Get server hash using super without any refresh action triggered
|
27 |
+
Returns: git hash of gradio server
|
28 |
+
"""
|
29 |
+
return super().submit(api_name='/system_hash').result()
|
30 |
+
|
31 |
+
def refresh_client_if_should(self):
|
32 |
+
# get current hash in order to update api_name -> fn_index map in case gradio server changed
|
33 |
+
# FIXME: Could add cli api as hash
|
34 |
+
server_hash = self.get_server_hash()
|
35 |
+
if self.server_hash != server_hash:
|
36 |
+
self.refresh_client()
|
37 |
+
self.server_hash = server_hash
|
38 |
+
else:
|
39 |
+
self.reset_session()
|
40 |
+
|
41 |
+
def refresh_client(self):
|
42 |
+
"""
|
43 |
+
Ensure every client call is independent
|
44 |
+
Also ensure map between api_name and fn_index is updated in case server changed (e.g. restarted with new code)
|
45 |
+
Returns:
|
46 |
+
"""
|
47 |
+
# need session hash to be new every time, to avoid "generator already executing"
|
48 |
+
self.reset_session()
|
49 |
+
|
50 |
+
client = Client(*self.args, **self.kwargs)
|
51 |
+
for k, v in client.__dict__.items():
|
52 |
+
setattr(self, k, v)
|
53 |
+
|
54 |
+
def submit(
|
55 |
+
self,
|
56 |
+
*args,
|
57 |
+
api_name: str | None = None,
|
58 |
+
fn_index: int | None = None,
|
59 |
+
result_callbacks: Callable | list[Callable] | None = None,
|
60 |
+
) -> Job:
|
61 |
+
# Note predict calls submit
|
62 |
+
try:
|
63 |
+
self.refresh_client_if_should()
|
64 |
+
job = super().submit(*args, api_name=api_name, fn_index=fn_index)
|
65 |
+
except Exception as e:
|
66 |
+
print("Hit e=%s" % str(e), flush=True)
|
67 |
+
# force reconfig in case only that
|
68 |
+
self.refresh_client()
|
69 |
+
job = super().submit(*args, api_name=api_name, fn_index=fn_index)
|
70 |
+
|
71 |
+
# see if immediately failed
|
72 |
+
e = job.future._exception
|
73 |
+
if e is not None:
|
74 |
+
print("GR job failed: %s %s" % (str(e), ''.join(traceback.format_tb(e.__traceback__))), flush=True)
|
75 |
+
# force reconfig in case only that
|
76 |
+
self.refresh_client()
|
77 |
+
job = super().submit(*args, api_name=api_name, fn_index=fn_index)
|
78 |
+
e2 = job.future._exception
|
79 |
+
if e2 is not None:
|
80 |
+
print("GR job failed again: %s\n%s" % (str(e2), ''.join(traceback.format_tb(e2.__traceback__))), flush=True)
|
81 |
+
|
82 |
+
return job
|
gradio_utils/prompt_form.py
ADDED
@@ -0,0 +1,108 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import math
|
3 |
+
|
4 |
+
import gradio as gr
|
5 |
+
|
6 |
+
|
7 |
+
def make_chatbots(output_label0, output_label0_model2, **kwargs):
|
8 |
+
visible_models = kwargs['visible_models']
|
9 |
+
all_models = kwargs['all_models']
|
10 |
+
|
11 |
+
text_outputs = []
|
12 |
+
chat_kwargs = []
|
13 |
+
for model_state_locki, model_state_lock in enumerate(kwargs['model_states']):
|
14 |
+
if os.environ.get('DEBUG_MODEL_LOCK'):
|
15 |
+
model_name = model_state_lock["base_model"] + " : " + model_state_lock["inference_server"]
|
16 |
+
else:
|
17 |
+
model_name = model_state_lock["base_model"]
|
18 |
+
output_label = f'h2oGPT [{model_name}]'
|
19 |
+
min_width = 250 if kwargs['gradio_size'] in ['small', 'large', 'medium'] else 160
|
20 |
+
chat_kwargs.append(dict(label=output_label, elem_classes='chatsmall',
|
21 |
+
height=kwargs['height'] or 400, min_width=min_width,
|
22 |
+
show_copy_button=kwargs['show_copy_button'],
|
23 |
+
visible=kwargs['model_lock'] and (visible_models is None or
|
24 |
+
model_state_locki in visible_models or
|
25 |
+
all_models[model_state_locki] in visible_models
|
26 |
+
)))
|
27 |
+
|
28 |
+
# base view on initial visible choice
|
29 |
+
if visible_models:
|
30 |
+
len_visible = len(visible_models)
|
31 |
+
else:
|
32 |
+
len_visible = len(kwargs['model_states'])
|
33 |
+
if kwargs['model_lock_columns'] == -1:
|
34 |
+
kwargs['model_lock_columns'] = len_visible
|
35 |
+
if kwargs['model_lock_columns'] is None:
|
36 |
+
kwargs['model_lock_columns'] = 3
|
37 |
+
|
38 |
+
ncols = kwargs['model_lock_columns']
|
39 |
+
if kwargs['model_states'] == 0:
|
40 |
+
nrows = 0
|
41 |
+
else:
|
42 |
+
nrows = math.ceil(len_visible / kwargs['model_lock_columns'])
|
43 |
+
|
44 |
+
if kwargs['model_lock_columns'] == 0:
|
45 |
+
# not using model_lock
|
46 |
+
pass
|
47 |
+
elif nrows <= 1:
|
48 |
+
with gr.Row():
|
49 |
+
for chat_kwargs1, model_state_lock in zip(chat_kwargs, kwargs['model_states']):
|
50 |
+
text_outputs.append(gr.Chatbot(**chat_kwargs1))
|
51 |
+
elif nrows == kwargs['model_states']:
|
52 |
+
with gr.Row():
|
53 |
+
for chat_kwargs1, model_state_lock in zip(chat_kwargs, kwargs['model_states']):
|
54 |
+
text_outputs.append(gr.Chatbot(**chat_kwargs1))
|
55 |
+
elif nrows == 2:
|
56 |
+
with gr.Row():
|
57 |
+
for mii, (chat_kwargs1, model_state_lock) in enumerate(zip(chat_kwargs, kwargs['model_states'])):
|
58 |
+
if mii >= len_visible / 2:
|
59 |
+
continue
|
60 |
+
text_outputs.append(gr.Chatbot(**chat_kwargs1))
|
61 |
+
with gr.Row():
|
62 |
+
for mii, (chat_kwargs1, model_state_lock) in enumerate(zip(chat_kwargs, kwargs['model_states'])):
|
63 |
+
if mii < len_visible / 2:
|
64 |
+
continue
|
65 |
+
text_outputs.append(gr.Chatbot(**chat_kwargs1))
|
66 |
+
elif nrows == 3:
|
67 |
+
with gr.Row():
|
68 |
+
for mii, (chat_kwargs1, model_state_lock) in enumerate(zip(chat_kwargs, kwargs['model_states'])):
|
69 |
+
if mii >= 1 * len_visible / 3:
|
70 |
+
continue
|
71 |
+
text_outputs.append(gr.Chatbot(**chat_kwargs1))
|
72 |
+
with gr.Row():
|
73 |
+
for mii, (chat_kwargs1, model_state_lock) in enumerate(zip(chat_kwargs, kwargs['model_states'])):
|
74 |
+
if mii < 1 * len_visible / 3 or mii >= 2 * len_visible / 3:
|
75 |
+
continue
|
76 |
+
text_outputs.append(gr.Chatbot(**chat_kwargs1))
|
77 |
+
with gr.Row():
|
78 |
+
for mii, (chat_kwargs1, model_state_lock) in enumerate(zip(chat_kwargs, kwargs['model_states'])):
|
79 |
+
if mii < 2 * len_visible / 3:
|
80 |
+
continue
|
81 |
+
text_outputs.append(gr.Chatbot(**chat_kwargs1))
|
82 |
+
elif nrows >= 4:
|
83 |
+
with gr.Row():
|
84 |
+
for mii, (chat_kwargs1, model_state_lock) in enumerate(zip(chat_kwargs, kwargs['model_states'])):
|
85 |
+
if mii >= 1 * len_visible / 4:
|
86 |
+
continue
|
87 |
+
text_outputs.append(gr.Chatbot(**chat_kwargs1))
|
88 |
+
with gr.Row():
|
89 |
+
for mii, (chat_kwargs1, model_state_lock) in enumerate(zip(chat_kwargs, kwargs['model_states'])):
|
90 |
+
if mii < 1 * len_visible / 4 or mii >= 2 * len_visible / 4:
|
91 |
+
continue
|
92 |
+
text_outputs.append(gr.Chatbot(**chat_kwargs1))
|
93 |
+
with gr.Row():
|
94 |
+
for mii, (chat_kwargs1, model_state_lock) in enumerate(zip(chat_kwargs, kwargs['model_states'])):
|
95 |
+
if mii < 2 * len_visible / 4 or mii >= 3 * len_visible / 4:
|
96 |
+
continue
|
97 |
+
text_outputs.append(gr.Chatbot(**chat_kwargs1))
|
98 |
+
with gr.Row():
|
99 |
+
for mii, (chat_kwargs1, model_state_lock) in enumerate(zip(chat_kwargs, kwargs['model_states'])):
|
100 |
+
if mii < 3 * len_visible / 4:
|
101 |
+
continue
|
102 |
+
text_outputs.append(gr.Chatbot(**chat_kwargs1))
|
103 |
+
|
104 |
+
with gr.Row():
|
105 |
+
text_output = gr.Chatbot(label=output_label0, visible=not kwargs['model_lock'], height=kwargs['height'] or 400)
|
106 |
+
text_output2 = gr.Chatbot(label=output_label0_model2,
|
107 |
+
visible=False and not kwargs['model_lock'], height=kwargs['height'] or 400)
|
108 |
+
return text_output, text_output2, text_outputs
|
h2o-logo.svg
ADDED
iterators/__init__.py
ADDED
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from .timeout_iterator import TimeoutIterator, AsyncTimeoutIterator
|
2 |
+
from .iterator_pipe import IteratorPipe, AsyncIteratorPipe
|
3 |
+
|
4 |
+
__all__ = ["TimeoutIterator", "AsyncTimeoutIterator", "IteratorPipe", "AsyncIteratorPipe"]
|
iterators/__pycache__/__init__.cpython-310.pyc
ADDED
Binary file (337 Bytes). View file
|
|
iterators/__pycache__/iterator_pipe.cpython-310.pyc
ADDED
Binary file (2.71 kB). View file
|
|
iterators/__pycache__/timeout_iterator.cpython-310.pyc
ADDED
Binary file (5.63 kB). View file
|
|
iterators/iterator_pipe.py
ADDED
@@ -0,0 +1,93 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import queue
|
2 |
+
import asyncio
|
3 |
+
|
4 |
+
|
5 |
+
class IteratorPipe:
|
6 |
+
"""
|
7 |
+
Iterator Pipe creates an iterator that can be fed in data from another block of code or thread of execution
|
8 |
+
"""
|
9 |
+
|
10 |
+
def __init__(self, sentinel=object()):
|
11 |
+
self._q = queue.Queue()
|
12 |
+
self._sentinel = sentinel
|
13 |
+
self._sentinel_pushed = False
|
14 |
+
self._closed = False
|
15 |
+
|
16 |
+
def __iter__(self):
|
17 |
+
return self
|
18 |
+
|
19 |
+
def __next__(self):
|
20 |
+
if self._closed:
|
21 |
+
raise StopIteration
|
22 |
+
|
23 |
+
data = self._q.get(block=True)
|
24 |
+
if data is self._sentinel:
|
25 |
+
self._closed = True
|
26 |
+
raise StopIteration
|
27 |
+
|
28 |
+
return data
|
29 |
+
|
30 |
+
def put(self, data) -> bool:
|
31 |
+
"""
|
32 |
+
Pushes next item to Iterator and returns True
|
33 |
+
If iterator has been closed via close(), doesn't push anything and returns False
|
34 |
+
"""
|
35 |
+
if self._sentinel_pushed:
|
36 |
+
return False
|
37 |
+
|
38 |
+
self._q.put(data)
|
39 |
+
return True
|
40 |
+
|
41 |
+
def close(self):
|
42 |
+
"""
|
43 |
+
Close is idempotent. Calling close multiple times is safe
|
44 |
+
Iterator will raise StopIteration only after all elements pushed before close have been iterated
|
45 |
+
"""
|
46 |
+
# make close idempotent
|
47 |
+
if not self._sentinel_pushed:
|
48 |
+
self._sentinel_pushed = True
|
49 |
+
self._q.put(self._sentinel)
|
50 |
+
|
51 |
+
|
52 |
+
class AsyncIteratorPipe:
|
53 |
+
|
54 |
+
def __init__(self, sentinel=object()):
|
55 |
+
self._q = asyncio.Queue()
|
56 |
+
self._sentinel = sentinel
|
57 |
+
self._sentinel_pushed = False
|
58 |
+
self._closed = False
|
59 |
+
|
60 |
+
def __aiter__(self):
|
61 |
+
return self
|
62 |
+
|
63 |
+
async def __anext__(self):
|
64 |
+
if self._closed:
|
65 |
+
raise StopAsyncIteration
|
66 |
+
|
67 |
+
data = await self._q.get()
|
68 |
+
if data is self._sentinel:
|
69 |
+
self._closed = True
|
70 |
+
raise StopAsyncIteration
|
71 |
+
|
72 |
+
return data
|
73 |
+
|
74 |
+
async def put(self, data) -> bool:
|
75 |
+
"""
|
76 |
+
Pushes next item to Iterator and returns True
|
77 |
+
If iterator has been closed via close(), doesn't push anything and returns False
|
78 |
+
"""
|
79 |
+
if self._sentinel_pushed:
|
80 |
+
return False
|
81 |
+
|
82 |
+
await self._q.put(data)
|
83 |
+
return True
|
84 |
+
|
85 |
+
async def close(self):
|
86 |
+
"""
|
87 |
+
Close is idempotent. Calling close multiple times is safe
|
88 |
+
Iterator will raise StopIteration only after all elements pushed before close have been iterated
|
89 |
+
"""
|
90 |
+
# make close idempotent
|
91 |
+
if not self._sentinel_pushed:
|
92 |
+
self._sentinel_pushed = True
|
93 |
+
await self._q.put(self._sentinel)
|
iterators/timeout_iterator.py
ADDED
@@ -0,0 +1,170 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import queue
|
2 |
+
import asyncio
|
3 |
+
import threading
|
4 |
+
import traceback
|
5 |
+
|
6 |
+
|
7 |
+
class TimeoutIterator:
|
8 |
+
"""
|
9 |
+
Wrapper class to add timeout feature to synchronous iterators
|
10 |
+
- timeout: timeout for next(). Default=ZERO_TIMEOUT i.e. no timeout or blocking calls to next. Updated using set_timeout()
|
11 |
+
- sentinel: the object returned by iterator when timeout happens
|
12 |
+
- reset_on_next: if set to True, timeout is reset to the value of ZERO_TIMEOUT on each iteration
|
13 |
+
|
14 |
+
TimeoutIterator uses a thread internally.
|
15 |
+
The thread stops once the iterator exhausts or raises an exception during iteration.
|
16 |
+
|
17 |
+
Any exceptions raised within the wrapped iterator are propagated as it is.
|
18 |
+
Exception is raised when all elements generated by the actual iterator before exception have been consumed
|
19 |
+
Timeout can be set dynamically before going for iteration
|
20 |
+
"""
|
21 |
+
ZERO_TIMEOUT = 0.0
|
22 |
+
|
23 |
+
def __init__(self, iterator, timeout=0.0, sentinel=object(), reset_on_next=False, raise_on_exception=True):
|
24 |
+
self._iterator = iterator
|
25 |
+
self._timeout = timeout
|
26 |
+
self._sentinel = sentinel
|
27 |
+
self._reset_on_next = reset_on_next
|
28 |
+
self._raise_on_exception = raise_on_exception
|
29 |
+
|
30 |
+
self._interrupt = False
|
31 |
+
self._done = False
|
32 |
+
self._buffer = queue.Queue()
|
33 |
+
self._thread = threading.Thread(target=self.__lookahead)
|
34 |
+
self._thread.start()
|
35 |
+
|
36 |
+
def get_sentinel(self):
|
37 |
+
return self._sentinel
|
38 |
+
|
39 |
+
def set_reset_on_next(self, reset_on_next):
|
40 |
+
self._reset_on_next = reset_on_next
|
41 |
+
|
42 |
+
def set_timeout(self, timeout: float):
|
43 |
+
"""
|
44 |
+
Set timeout for next iteration
|
45 |
+
"""
|
46 |
+
self._timeout = timeout
|
47 |
+
|
48 |
+
def interrupt(self):
|
49 |
+
"""
|
50 |
+
interrupt and stop the underlying thread.
|
51 |
+
the thread actually dies only after interrupt has been set and
|
52 |
+
the underlying iterator yields a value after that.
|
53 |
+
"""
|
54 |
+
self._interrupt = True
|
55 |
+
|
56 |
+
def __iter__(self):
|
57 |
+
return self
|
58 |
+
|
59 |
+
def __next__(self):
|
60 |
+
"""
|
61 |
+
yield the result from iterator
|
62 |
+
if timeout > 0:
|
63 |
+
yield data if available.
|
64 |
+
otherwise yield sentinal
|
65 |
+
"""
|
66 |
+
if self._done:
|
67 |
+
raise StopIteration
|
68 |
+
|
69 |
+
data = self._sentinel
|
70 |
+
try:
|
71 |
+
if self._timeout > self.ZERO_TIMEOUT:
|
72 |
+
data = self._buffer.get(timeout=self._timeout)
|
73 |
+
else:
|
74 |
+
data = self._buffer.get()
|
75 |
+
except queue.Empty:
|
76 |
+
pass
|
77 |
+
finally:
|
78 |
+
# see if timeout needs to be reset
|
79 |
+
if self._reset_on_next:
|
80 |
+
self._timeout = self.ZERO_TIMEOUT
|
81 |
+
|
82 |
+
# propagate any exceptions including StopIteration
|
83 |
+
if isinstance(data, BaseException):
|
84 |
+
self._done = True
|
85 |
+
if isinstance(data, StopIteration):
|
86 |
+
raise data
|
87 |
+
ex = ''.join(traceback.format_tb(data.__traceback__))
|
88 |
+
print("Generation Failed: %s %s" % (str(data), str(ex)), flush=True)
|
89 |
+
if self._raise_on_exception:
|
90 |
+
raise data
|
91 |
+
else:
|
92 |
+
return data
|
93 |
+
|
94 |
+
return data
|
95 |
+
|
96 |
+
def __lookahead(self):
|
97 |
+
try:
|
98 |
+
while True:
|
99 |
+
self._buffer.put(next(self._iterator))
|
100 |
+
if self._interrupt:
|
101 |
+
raise StopIteration()
|
102 |
+
except BaseException as e:
|
103 |
+
self._buffer.put(e)
|
104 |
+
|
105 |
+
|
106 |
+
class AsyncTimeoutIterator:
|
107 |
+
"""
|
108 |
+
Async version of TimeoutIterator. See method documentation of TimeoutIterator
|
109 |
+
"""
|
110 |
+
ZERO_TIMEOUT = 0.0
|
111 |
+
|
112 |
+
def __init__(self, iterator, timeout=0.0, sentinel=object(), reset_on_next=False):
|
113 |
+
self._iterator = iterator
|
114 |
+
self._timeout = timeout
|
115 |
+
self._sentinel = sentinel
|
116 |
+
self._reset_on_next = reset_on_next
|
117 |
+
|
118 |
+
self._interrupt = False
|
119 |
+
self._done = False
|
120 |
+
self._buffer = asyncio.Queue()
|
121 |
+
self._task = asyncio.get_event_loop().create_task(self.__lookahead())
|
122 |
+
|
123 |
+
def get_sentinel(self):
|
124 |
+
return self._sentinel
|
125 |
+
|
126 |
+
def set_reset_on_next(self, reset_on_next):
|
127 |
+
self._reset_on_next = reset_on_next
|
128 |
+
|
129 |
+
def set_timeout(self, timeout: float):
|
130 |
+
self._timeout = timeout
|
131 |
+
|
132 |
+
def interrupt(self):
|
133 |
+
self._interrupt = True
|
134 |
+
|
135 |
+
def __aiter__(self):
|
136 |
+
return self
|
137 |
+
|
138 |
+
async def __anext__(self):
|
139 |
+
if self._done:
|
140 |
+
raise StopAsyncIteration
|
141 |
+
|
142 |
+
data = self._sentinel
|
143 |
+
try:
|
144 |
+
if self._timeout > self.ZERO_TIMEOUT:
|
145 |
+
data = await asyncio.wait_for(self._buffer.get(), self._timeout)
|
146 |
+
else:
|
147 |
+
data = await self._buffer.get()
|
148 |
+
except asyncio.TimeoutError:
|
149 |
+
pass
|
150 |
+
finally:
|
151 |
+
# see if timeout needs to be reset
|
152 |
+
if self._reset_on_next:
|
153 |
+
self._timeout = self.ZERO_TIMEOUT
|
154 |
+
|
155 |
+
# propagate any exceptions including StopIteration
|
156 |
+
if isinstance(data, BaseException):
|
157 |
+
self._done = True
|
158 |
+
raise data
|
159 |
+
|
160 |
+
return data
|
161 |
+
|
162 |
+
async def __lookahead(self):
|
163 |
+
try:
|
164 |
+
while True:
|
165 |
+
data = await self._iterator.__anext__()
|
166 |
+
await self._buffer.put(data)
|
167 |
+
if self._interrupt:
|
168 |
+
raise StopAsyncIteration()
|
169 |
+
except BaseException as e:
|
170 |
+
await self._buffer.put(e)
|
requirements.txt
ADDED
@@ -0,0 +1,192 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# for generate (gradio server) and finetune
|
2 |
+
datasets==2.13.0
|
3 |
+
sentencepiece==0.1.99
|
4 |
+
gradio==3.41.2
|
5 |
+
huggingface_hub==0.16.4
|
6 |
+
appdirs==1.4.4
|
7 |
+
fire==0.5.0
|
8 |
+
docutils==0.20.1
|
9 |
+
torch==2.0.1; sys_platform != "darwin" and platform_machine != "arm64"
|
10 |
+
evaluate==0.4.0
|
11 |
+
rouge_score==0.1.2
|
12 |
+
sacrebleu==2.3.1
|
13 |
+
scikit-learn==1.2.2
|
14 |
+
# optional (need to uncomment code in gradio_runner.py for import of better_profanity)
|
15 |
+
# alt-profanity-check==1.2.2
|
16 |
+
# better-profanity==0.7.0
|
17 |
+
numpy==1.24.3
|
18 |
+
pandas==2.0.2
|
19 |
+
matplotlib==3.7.1
|
20 |
+
loralib==0.1.1
|
21 |
+
bitsandbytes==0.41.1
|
22 |
+
accelerate==0.22.0
|
23 |
+
peft==0.5.0
|
24 |
+
transformers==4.33.1
|
25 |
+
tokenizers==0.13.3
|
26 |
+
APScheduler==3.10.1
|
27 |
+
|
28 |
+
# optional for generate
|
29 |
+
pynvml==11.5.0
|
30 |
+
psutil==5.9.5
|
31 |
+
boto3==1.26.101
|
32 |
+
botocore==1.29.101
|
33 |
+
|
34 |
+
# optional for finetune
|
35 |
+
tensorboard==2.13.0
|
36 |
+
neptune==1.2.0
|
37 |
+
|
38 |
+
# for gradio client
|
39 |
+
gradio_client==0.5.0
|
40 |
+
beautifulsoup4==4.12.2
|
41 |
+
markdown==3.4.3
|
42 |
+
|
43 |
+
# data and testing
|
44 |
+
pytest==7.2.2
|
45 |
+
pytest-xdist==3.2.1
|
46 |
+
nltk==3.8.1
|
47 |
+
textstat==0.7.3
|
48 |
+
# pandoc==2.3
|
49 |
+
pypandoc==1.11; sys_platform == "darwin" and platform_machine == "arm64"
|
50 |
+
pypandoc_binary==1.11; platform_machine == "x86_64"
|
51 |
+
pypandoc_binary==1.11; sys_platform == "win32"
|
52 |
+
python-magic-bin==0.4.14; sys_platform == "win32"
|
53 |
+
openpyxl==3.1.2
|
54 |
+
lm_dataformat==0.0.20
|
55 |
+
bioc==2.0
|
56 |
+
|
57 |
+
# falcon
|
58 |
+
einops==0.6.1
|
59 |
+
instructorembedding==1.0.1
|
60 |
+
|
61 |
+
# for gpt4all .env file, but avoid worrying about imports
|
62 |
+
python-dotenv==1.0.0
|
63 |
+
|
64 |
+
text-generation==0.6.0
|
65 |
+
# for tokenization when don't have HF tokenizer
|
66 |
+
tiktoken==0.4.0
|
67 |
+
|
68 |
+
requests>=2.31.0
|
69 |
+
urllib3>=1.26.16
|
70 |
+
filelock>=3.12.2
|
71 |
+
joblib>=1.3.1
|
72 |
+
tqdm>=4.65.0
|
73 |
+
tabulate>=0.9.0
|
74 |
+
packaging>=23.1
|
75 |
+
# optional for chat with PDF
|
76 |
+
langchain==0.0.300
|
77 |
+
pypdf==3.14.0
|
78 |
+
# avoid textract, requires old six
|
79 |
+
#textract==1.6.5
|
80 |
+
pypdfium2==4.19.0
|
81 |
+
|
82 |
+
# for HF embeddings
|
83 |
+
sentence_transformers==2.2.2
|
84 |
+
|
85 |
+
# optional: for OpenAI endpoint or embeddings (requires key)
|
86 |
+
openai==0.27.8
|
87 |
+
replicate==0.10.0
|
88 |
+
|
89 |
+
# local vector db
|
90 |
+
chromadb==0.4.10
|
91 |
+
|
92 |
+
# chroma migration
|
93 |
+
chroma-migrate==0.0.7
|
94 |
+
duckdb==0.7.1
|
95 |
+
https://h2o-release.s3.amazonaws.com/h2ogpt/chromamigdb-0.3.25-py3-none-any.whl
|
96 |
+
https://h2o-release.s3.amazonaws.com/h2ogpt/hnswmiglib-0.7.0.tgz
|
97 |
+
|
98 |
+
# server vector db
|
99 |
+
#pymilvus==2.2.8
|
100 |
+
|
101 |
+
# weak url support, if can't install opencv etc. If comment-in this one, then comment-out unstructured[local-inference]==0.6.6
|
102 |
+
# unstructured==0.8.1
|
103 |
+
|
104 |
+
# strong support for images
|
105 |
+
# Requires on Ubuntu: sudo apt-get install libmagic-dev poppler-utils tesseract-ocr libtesseract-dev libreoffice
|
106 |
+
unstructured[local-inference]==0.9.0
|
107 |
+
#pdf2image==1.16.3
|
108 |
+
#pytesseract==0.3.10
|
109 |
+
pillow==9.5.0
|
110 |
+
posthog==3.0.1
|
111 |
+
|
112 |
+
pdfminer.six==20221105
|
113 |
+
urllib3
|
114 |
+
requests_file
|
115 |
+
|
116 |
+
#pdf2image==1.16.3
|
117 |
+
#pytesseract==0.3.10
|
118 |
+
tabulate==0.9.0
|
119 |
+
# FYI pandoc already part of requirements.txt
|
120 |
+
|
121 |
+
# JSONLoader, but makes some trouble for some users
|
122 |
+
# TRY: apt-get install autoconf libtool
|
123 |
+
# unclear what happens on windows/mac for now
|
124 |
+
jq==1.4.1; platform_machine == "x86_64"
|
125 |
+
|
126 |
+
# to check licenses
|
127 |
+
# Run: pip-licenses|grep -v 'BSD\|Apache\|MIT'
|
128 |
+
pip-licenses==4.3.0
|
129 |
+
|
130 |
+
# weaviate vector db
|
131 |
+
weaviate-client==3.22.1
|
132 |
+
# optional for chat with PDF
|
133 |
+
langchain==0.0.300
|
134 |
+
pypdf==3.14.0
|
135 |
+
# avoid textract, requires old six
|
136 |
+
#textract==1.6.5
|
137 |
+
pypdfium2==4.19.0
|
138 |
+
|
139 |
+
# for HF embeddings
|
140 |
+
sentence_transformers==2.2.2
|
141 |
+
|
142 |
+
# optional: for OpenAI endpoint or embeddings (requires key)
|
143 |
+
openai==0.27.8
|
144 |
+
replicate==0.10.0
|
145 |
+
|
146 |
+
# local vector db
|
147 |
+
chromadb==0.4.10
|
148 |
+
|
149 |
+
# chroma migration
|
150 |
+
chroma-migrate==0.0.7
|
151 |
+
duckdb==0.7.1
|
152 |
+
https://h2o-release.s3.amazonaws.com/h2ogpt/chromamigdb-0.3.25-py3-none-any.whl
|
153 |
+
https://h2o-release.s3.amazonaws.com/h2ogpt/hnswmiglib-0.7.0.tgz
|
154 |
+
|
155 |
+
# server vector db
|
156 |
+
#pymilvus==2.2.8
|
157 |
+
|
158 |
+
# weak url support, if can't install opencv etc. If comment-in this one, then comment-out unstructured[local-inference]==0.6.6
|
159 |
+
# unstructured==0.8.1
|
160 |
+
|
161 |
+
# strong support for images
|
162 |
+
# Requires on Ubuntu: sudo apt-get install libmagic-dev poppler-utils tesseract-ocr libtesseract-dev libreoffice
|
163 |
+
unstructured[local-inference]==0.9.0
|
164 |
+
#pdf2image==1.16.3
|
165 |
+
#pytesseract==0.3.10
|
166 |
+
pillow==9.5.0
|
167 |
+
posthog==3.0.1
|
168 |
+
|
169 |
+
pdfminer.six==20221105
|
170 |
+
urllib3
|
171 |
+
requests_file
|
172 |
+
|
173 |
+
#pdf2image==1.16.3
|
174 |
+
#pytesseract==0.3.10
|
175 |
+
tabulate==0.9.0
|
176 |
+
# FYI pandoc already part of requirements.txt
|
177 |
+
|
178 |
+
# JSONLoader, but makes some trouble for some users
|
179 |
+
# TRY: apt-get install autoconf libtool
|
180 |
+
# unclear what happens on windows/mac for now
|
181 |
+
jq==1.4.1; platform_machine == "x86_64"
|
182 |
+
|
183 |
+
# to check licenses
|
184 |
+
# Run: pip-licenses|grep -v 'BSD\|Apache\|MIT'
|
185 |
+
pip-licenses==4.3.0
|
186 |
+
|
187 |
+
# weaviate vector db
|
188 |
+
weaviate-client==3.22.1
|
189 |
+
faiss-gpu==1.7.2
|
190 |
+
arxiv==1.4.8
|
191 |
+
pymupdf==1.23.1 # AGPL license
|
192 |
+
# extract-msg==0.41.1 # GPL3
|
src/LICENSE
ADDED
@@ -0,0 +1,201 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
Apache License
|
2 |
+
Version 2.0, January 2004
|
3 |
+
http://www.apache.org/licenses/
|
4 |
+
|
5 |
+
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
|
6 |
+
|
7 |
+
1. Definitions.
|
8 |
+
|
9 |
+
"License" shall mean the terms and conditions for use, reproduction,
|
10 |
+
and distribution as defined by Sections 1 through 9 of this document.
|
11 |
+
|
12 |
+
"Licensor" shall mean the copyright owner or entity authorized by
|
13 |
+
the copyright owner that is granting the License.
|
14 |
+
|
15 |
+
"Legal Entity" shall mean the union of the acting entity and all
|
16 |
+
other entities that control, are controlled by, or are under common
|
17 |
+
control with that entity. For the purposes of this definition,
|
18 |
+
"control" means (i) the power, direct or indirect, to cause the
|
19 |
+
direction or management of such entity, whether by contract or
|
20 |
+
otherwise, or (ii) ownership of fifty percent (50%) or more of the
|
21 |
+
outstanding shares, or (iii) beneficial ownership of such entity.
|
22 |
+
|
23 |
+
"You" (or "Your") shall mean an individual or Legal Entity
|
24 |
+
exercising permissions granted by this License.
|
25 |
+
|
26 |
+
"Source" form shall mean the preferred form for making modifications,
|
27 |
+
including but not limited to software source code, documentation
|
28 |
+
source, and configuration files.
|
29 |
+
|
30 |
+
"Object" form shall mean any form resulting from mechanical
|
31 |
+
transformation or translation of a Source form, including but
|
32 |
+
not limited to compiled object code, generated documentation,
|
33 |
+
and conversions to other media types.
|
34 |
+
|
35 |
+
"Work" shall mean the work of authorship, whether in Source or
|
36 |
+
Object form, made available under the License, as indicated by a
|
37 |
+
copyright notice that is included in or attached to the work
|
38 |
+
(an example is provided in the Appendix below).
|
39 |
+
|
40 |
+
"Derivative Works" shall mean any work, whether in Source or Object
|
41 |
+
form, that is based on (or derived from) the Work and for which the
|
42 |
+
editorial revisions, annotations, elaborations, or other modifications
|
43 |
+
represent, as a whole, an original work of authorship. For the purposes
|
44 |
+
of this License, Derivative Works shall not include works that remain
|
45 |
+
separable from, or merely link (or bind by name) to the interfaces of,
|
46 |
+
the Work and Derivative Works thereof.
|
47 |
+
|
48 |
+
"Contribution" shall mean any work of authorship, including
|
49 |
+
the original version of the Work and any modifications or additions
|
50 |
+
to that Work or Derivative Works thereof, that is intentionally
|
51 |
+
submitted to Licensor for inclusion in the Work by the copyright owner
|
52 |
+
or by an individual or Legal Entity authorized to submit on behalf of
|
53 |
+
the copyright owner. For the purposes of this definition, "submitted"
|
54 |
+
means any form of electronic, verbal, or written communication sent
|
55 |
+
to the Licensor or its representatives, including but not limited to
|
56 |
+
communication on electronic mailing lists, source code control systems,
|
57 |
+
and issue tracking systems that are managed by, or on behalf of, the
|
58 |
+
Licensor for the purpose of discussing and improving the Work, but
|
59 |
+
excluding communication that is conspicuously marked or otherwise
|
60 |
+
designated in writing by the copyright owner as "Not a Contribution."
|
61 |
+
|
62 |
+
"Contributor" shall mean Licensor and any individual or Legal Entity
|
63 |
+
on behalf of whom a Contribution has been received by Licensor and
|
64 |
+
subsequently incorporated within the Work.
|
65 |
+
|
66 |
+
2. Grant of Copyright License. Subject to the terms and conditions of
|
67 |
+
this License, each Contributor hereby grants to You a perpetual,
|
68 |
+
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
69 |
+
copyright license to reproduce, prepare Derivative Works of,
|
70 |
+
publicly display, publicly perform, sublicense, and distribute the
|
71 |
+
Work and such Derivative Works in Source or Object form.
|
72 |
+
|
73 |
+
3. Grant of Patent License. Subject to the terms and conditions of
|
74 |
+
this License, each Contributor hereby grants to You a perpetual,
|
75 |
+
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
76 |
+
(except as stated in this section) patent license to make, have made,
|
77 |
+
use, offer to sell, sell, import, and otherwise transfer the Work,
|
78 |
+
where such license applies only to those patent claims licensable
|
79 |
+
by such Contributor that are necessarily infringed by their
|
80 |
+
Contribution(s) alone or by combination of their Contribution(s)
|
81 |
+
with the Work to which such Contribution(s) was submitted. If You
|
82 |
+
institute patent litigation against any entity (including a
|
83 |
+
cross-claim or counterclaim in a lawsuit) alleging that the Work
|
84 |
+
or a Contribution incorporated within the Work constitutes direct
|
85 |
+
or contributory patent infringement, then any patent licenses
|
86 |
+
granted to You under this License for that Work shall terminate
|
87 |
+
as of the date such litigation is filed.
|
88 |
+
|
89 |
+
4. Redistribution. You may reproduce and distribute copies of the
|
90 |
+
Work or Derivative Works thereof in any medium, with or without
|
91 |
+
modifications, and in Source or Object form, provided that You
|
92 |
+
meet the following conditions:
|
93 |
+
|
94 |
+
(a) You must give any other recipients of the Work or
|
95 |
+
Derivative Works a copy of this License; and
|
96 |
+
|
97 |
+
(b) You must cause any modified files to carry prominent notices
|
98 |
+
stating that You changed the files; and
|
99 |
+
|
100 |
+
(c) You must retain, in the Source form of any Derivative Works
|
101 |
+
that You distribute, all copyright, patent, trademark, and
|
102 |
+
attribution notices from the Source form of the Work,
|
103 |
+
excluding those notices that do not pertain to any part of
|
104 |
+
the Derivative Works; and
|
105 |
+
|
106 |
+
(d) If the Work includes a "NOTICE" text file as part of its
|
107 |
+
distribution, then any Derivative Works that You distribute must
|
108 |
+
include a readable copy of the attribution notices contained
|
109 |
+
within such NOTICE file, excluding those notices that do not
|
110 |
+
pertain to any part of the Derivative Works, in at least one
|
111 |
+
of the following places: within a NOTICE text file distributed
|
112 |
+
as part of the Derivative Works; within the Source form or
|
113 |
+
documentation, if provided along with the Derivative Works; or,
|
114 |
+
within a display generated by the Derivative Works, if and
|
115 |
+
wherever such third-party notices normally appear. The contents
|
116 |
+
of the NOTICE file are for informational purposes only and
|
117 |
+
do not modify the License. You may add Your own attribution
|
118 |
+
notices within Derivative Works that You distribute, alongside
|
119 |
+
or as an addendum to the NOTICE text from the Work, provided
|
120 |
+
that such additional attribution notices cannot be construed
|
121 |
+
as modifying the License.
|
122 |
+
|
123 |
+
You may add Your own copyright statement to Your modifications and
|
124 |
+
may provide additional or different license terms and conditions
|
125 |
+
for use, reproduction, or distribution of Your modifications, or
|
126 |
+
for any such Derivative Works as a whole, provided Your use,
|
127 |
+
reproduction, and distribution of the Work otherwise complies with
|
128 |
+
the conditions stated in this License.
|
129 |
+
|
130 |
+
5. Submission of Contributions. Unless You explicitly state otherwise,
|
131 |
+
any Contribution intentionally submitted for inclusion in the Work
|
132 |
+
by You to the Licensor shall be under the terms and conditions of
|
133 |
+
this License, without any additional terms or conditions.
|
134 |
+
Notwithstanding the above, nothing herein shall supersede or modify
|
135 |
+
the terms of any separate license agreement you may have executed
|
136 |
+
with Licensor regarding such Contributions.
|
137 |
+
|
138 |
+
6. Trademarks. This License does not grant permission to use the trade
|
139 |
+
names, trademarks, service marks, or product names of the Licensor,
|
140 |
+
except as required for reasonable and customary use in describing the
|
141 |
+
origin of the Work and reproducing the content of the NOTICE file.
|
142 |
+
|
143 |
+
7. Disclaimer of Warranty. Unless required by applicable law or
|
144 |
+
agreed to in writing, Licensor provides the Work (and each
|
145 |
+
Contributor provides its Contributions) on an "AS IS" BASIS,
|
146 |
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
|
147 |
+
implied, including, without limitation, any warranties or conditions
|
148 |
+
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
|
149 |
+
PARTICULAR PURPOSE. You are solely responsible for determining the
|
150 |
+
appropriateness of using or redistributing the Work and assume any
|
151 |
+
risks associated with Your exercise of permissions under this License.
|
152 |
+
|
153 |
+
8. Limitation of Liability. In no event and under no legal theory,
|
154 |
+
whether in tort (including negligence), contract, or otherwise,
|
155 |
+
unless required by applicable law (such as deliberate and grossly
|
156 |
+
negligent acts) or agreed to in writing, shall any Contributor be
|
157 |
+
liable to You for damages, including any direct, indirect, special,
|
158 |
+
incidental, or consequential damages of any character arising as a
|
159 |
+
result of this License or out of the use or inability to use the
|
160 |
+
Work (including but not limited to damages for loss of goodwill,
|
161 |
+
work stoppage, computer failure or malfunction, or any and all
|
162 |
+
other commercial damages or losses), even if such Contributor
|
163 |
+
has been advised of the possibility of such damages.
|
164 |
+
|
165 |
+
9. Accepting Warranty or Additional Liability. While redistributing
|
166 |
+
the Work or Derivative Works thereof, You may choose to offer,
|
167 |
+
and charge a fee for, acceptance of support, warranty, indemnity,
|
168 |
+
or other liability obligations and/or rights consistent with this
|
169 |
+
License. However, in accepting such obligations, You may act only
|
170 |
+
on Your own behalf and on Your sole responsibility, not on behalf
|
171 |
+
of any other Contributor, and only if You agree to indemnify,
|
172 |
+
defend, and hold each Contributor harmless for any liability
|
173 |
+
incurred by, or claims asserted against, such Contributor by reason
|
174 |
+
of your accepting any such warranty or additional liability.
|
175 |
+
|
176 |
+
END OF TERMS AND CONDITIONS
|
177 |
+
|
178 |
+
APPENDIX: How to apply the Apache License to your work.
|
179 |
+
|
180 |
+
To apply the Apache License to your work, attach the following
|
181 |
+
boilerplate notice, with the fields enclosed by brackets "[]"
|
182 |
+
replaced with your own identifying information. (Don't include
|
183 |
+
the brackets!) The text should be enclosed in the appropriate
|
184 |
+
comment syntax for the file format. We also recommend that a
|
185 |
+
file or class name and description of purpose be included on the
|
186 |
+
same "printed page" as the copyright notice for easier
|
187 |
+
identification within third-party archives.
|
188 |
+
|
189 |
+
Copyright 2023 Rohan Taori, Ishaan Gulrajani, Tianyi Zhang, Yann Dubois, Xuechen Li
|
190 |
+
|
191 |
+
Licensed under the Apache License, Version 2.0 (the "License");
|
192 |
+
you may not use this file except in compliance with the License.
|
193 |
+
You may obtain a copy of the License at
|
194 |
+
|
195 |
+
http://www.apache.org/licenses/LICENSE-2.0
|
196 |
+
|
197 |
+
Unless required by applicable law or agreed to in writing, software
|
198 |
+
distributed under the License is distributed on an "AS IS" BASIS,
|
199 |
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
200 |
+
See the License for the specific language governing permissions and
|
201 |
+
limitations under the License.
|
src/__pycache__/enums.cpython-310.pyc
ADDED
Binary file (6.52 kB). View file
|
|
src/__pycache__/evaluate_params.cpython-310.pyc
ADDED
Binary file (1.32 kB). View file
|
|
src/__pycache__/gen.cpython-310.pyc
ADDED
Binary file (102 kB). View file
|
|
src/__pycache__/gen.cpython-312.pyc
ADDED
Binary file (148 kB). View file
|
|
src/__pycache__/gpt_langchain.cpython-310.pyc
ADDED
Binary file (122 kB). View file
|
|
src/__pycache__/loaders.cpython-310.pyc
ADDED
Binary file (3.38 kB). View file
|
|
src/__pycache__/prompter.cpython-310.pyc
ADDED
Binary file (25.7 kB). View file
|
|
src/__pycache__/stopping.cpython-310.pyc
ADDED
Binary file (5.18 kB). View file
|
|
src/__pycache__/utils.cpython-310.pyc
ADDED
Binary file (40.9 kB). View file
|
|
src/client_test.py
ADDED
@@ -0,0 +1,484 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Client test.
|
3 |
+
|
4 |
+
Run server:
|
5 |
+
|
6 |
+
python generate.py --base_model=h2oai/h2ogpt-oig-oasst1-512-6_9b
|
7 |
+
|
8 |
+
NOTE: For private models, add --use-auth_token=True
|
9 |
+
|
10 |
+
NOTE: --use_gpu_id=True (default) must be used for multi-GPU in case see failures with cuda:x cuda:y mismatches.
|
11 |
+
Currently, this will force model to be on a single GPU.
|
12 |
+
|
13 |
+
Then run this client as:
|
14 |
+
|
15 |
+
python src/client_test.py
|
16 |
+
|
17 |
+
|
18 |
+
|
19 |
+
For HF spaces:
|
20 |
+
|
21 |
+
HOST="https://h2oai-h2ogpt-chatbot.hf.space" python src/client_test.py
|
22 |
+
|
23 |
+
Result:
|
24 |
+
|
25 |
+
Loaded as API: https://h2oai-h2ogpt-chatbot.hf.space ✔
|
26 |
+
{'instruction_nochat': 'Who are you?', 'iinput_nochat': '', 'response': 'I am h2oGPT, a large language model developed by LAION.', 'sources': ''}
|
27 |
+
|
28 |
+
|
29 |
+
For demo:
|
30 |
+
|
31 |
+
HOST="https://gpt.h2o.ai" python src/client_test.py
|
32 |
+
|
33 |
+
Result:
|
34 |
+
|
35 |
+
Loaded as API: https://gpt.h2o.ai ✔
|
36 |
+
{'instruction_nochat': 'Who are you?', 'iinput_nochat': '', 'response': 'I am h2oGPT, a chatbot created by LAION.', 'sources': ''}
|
37 |
+
|
38 |
+
NOTE: Raw output from API for nochat case is a string of a python dict and will remain so if other entries are added to dict:
|
39 |
+
|
40 |
+
{'response': "I'm h2oGPT, a large language model by H2O.ai, the visionary leader in democratizing AI.", 'sources': ''}
|
41 |
+
|
42 |
+
|
43 |
+
"""
|
44 |
+
import ast
|
45 |
+
import time
|
46 |
+
import os
|
47 |
+
import markdown # pip install markdown
|
48 |
+
import pytest
|
49 |
+
from bs4 import BeautifulSoup # pip install beautifulsoup4
|
50 |
+
|
51 |
+
try:
|
52 |
+
from enums import DocumentSubset, LangChainAction
|
53 |
+
except:
|
54 |
+
from src.enums import DocumentSubset, LangChainAction
|
55 |
+
|
56 |
+
from tests.utils import get_inf_server
|
57 |
+
|
58 |
+
debug = False
|
59 |
+
|
60 |
+
os.environ['HF_HUB_DISABLE_TELEMETRY'] = '1'
|
61 |
+
|
62 |
+
|
63 |
+
def get_client(serialize=True):
    """Construct a gradio client pointed at the inference server.

    :param serialize: forwarded to gradio Client; use False for chat/streaming endpoints.
    :return: connected gradio_client.Client instance
    """
    from gradio_client import Client

    gr_client = Client(get_inf_server(), serialize=serialize)
    if debug:
        # dump the full endpoint map when debugging
        print(gr_client.view_api(all_endpoints=True))
    return gr_client
|
70 |
+
|
71 |
+
|
72 |
+
def get_args(prompt, prompt_type=None, chat=False, stream_output=False,
             max_new_tokens=50,
             top_k_docs=3,
             langchain_mode='Disabled',
             add_chat_history_to_context=True,
             langchain_action=LangChainAction.QUERY.value,
             langchain_agents=None,
             prompt_dict=None,
             version=None,
             h2ogpt_key=None,
             visible_models=None,
             system_prompt='',  # default of no system prompt triggered by empty string
             add_search_to_context=False,
             chat_conversation=None,
             text_context_list=None,
             ):
    """Build the ordered kwargs/args pair expected by the server eval endpoints.

    Key order must match evaluate_params.eval_func_param_names, which is asserted below.

    :param chat: when True the prompt goes into `instruction` (and an empty chatbot
                 state is appended); otherwise into `instruction_nochat`
    :param version: 0 for the older API without system_prompt; None/>=1 for latest
    :return: (kwargs OrderedDict, list of values in endpoint order)
    """
    # fix: avoid shared mutable default argument (was langchain_agents=[])
    if langchain_agents is None:
        langchain_agents = []
    from collections import OrderedDict
    kwargs = OrderedDict(instruction=prompt if chat else '',  # only for chat=True
                         iinput='',  # only for chat=True
                         context='',
                         # streaming output is supported, loops over and outputs each generation in streaming mode
                         # but leave stream_output=False for simple input/output mode
                         stream_output=stream_output,
                         prompt_type=prompt_type,
                         prompt_dict=prompt_dict,
                         temperature=0.1,
                         top_p=0.75,
                         top_k=40,
                         num_beams=1,
                         max_new_tokens=max_new_tokens,
                         min_new_tokens=0,
                         early_stopping=False,
                         max_time=20,
                         repetition_penalty=1.0,
                         num_return_sequences=1,
                         do_sample=True,
                         chat=chat,
                         instruction_nochat=prompt if not chat else '',
                         iinput_nochat='',  # only for chat=False
                         langchain_mode=langchain_mode,
                         add_chat_history_to_context=add_chat_history_to_context,
                         langchain_action=langchain_action,
                         langchain_agents=langchain_agents,
                         top_k_docs=top_k_docs,
                         chunk=True,
                         chunk_size=512,
                         document_subset=DocumentSubset.Relevant.name,
                         document_choice=[],
                         pre_prompt_query=None,
                         prompt_query=None,
                         pre_prompt_summary=None,
                         prompt_summary=None,
                         system_prompt=system_prompt,
                         image_loaders=None,
                         pdf_loaders=None,
                         url_loaders=None,
                         jq_schema=None,
                         visible_models=visible_models,
                         h2ogpt_key=h2ogpt_key,
                         add_search_to_context=add_search_to_context,
                         chat_conversation=chat_conversation,
                         text_context_list=text_context_list,
                         docs_ordering_type=None,
                         min_max_new_tokens=None,
                         )
    diff = 0
    if version is None:
        # latest
        version = 1
    if version == 0:
        # old API is missing system_prompt among the eval params
        diff = 1
    if version >= 1:
        kwargs.update(dict(system_prompt=system_prompt))
        diff = 0

    # sanity check against the server-side parameter list; `diff` keys may be missing
    from evaluate_params import eval_func_param_names
    assert len(set(eval_func_param_names).difference(set(list(kwargs.keys())))) == diff
    if chat:
        # add chatbot output on end. Assumes serialize=False
        kwargs.update(dict(chatbot=[]))

    return kwargs, list(kwargs.values())
|
154 |
+
|
155 |
+
|
156 |
+
@pytest.mark.skip(reason="For manual use against some server, no server launched")
def test_client_basic(prompt_type='human_bot', version=None, visible_models=None, prompt='Who are you?',
                      h2ogpt_key=None):
    """Manual test: one non-chat round trip against a running server."""
    params = dict(prompt=prompt, prompt_type=prompt_type, max_new_tokens=50,
                  version=version, visible_models=visible_models, h2ogpt_key=h2ogpt_key)
    return run_client_nochat(**params)
|
161 |
+
|
162 |
+
|
163 |
+
"""
|
164 |
+
time HOST=https://gpt-internal.h2o.ai PYTHONPATH=. pytest -n 20 src/client_test.py::test_client_basic_benchmark
|
165 |
+
32 seconds to answer 20 questions at once with 70B llama2 on 4x A100 80GB using TGI 0.9.3
|
166 |
+
"""
|
167 |
+
|
168 |
+
|
169 |
+
@pytest.mark.skip(reason="For manual use against some server, no server launched")
@pytest.mark.parametrize("id", range(20))
def test_client_basic_benchmark(id, prompt_type='human_bot', version=None):
    """Benchmark: 20 parallel copies (via pytest -n) of one long non-chat query.

    The prompt is a verbatim pytest ImportError log used purely as a long input;
    see module-level note for timing results on 70B llama2 / 4x A100.
    `id` only exists to parametrize distinct pytest test ids.
    """
    return run_client_nochat(prompt="""
/nfs4/llm/h2ogpt/h2ogpt/bin/python /home/arno/pycharm-2022.2.2/plugins/python/helpers/pycharm/_jb_pytest_runner.py --target src/client_test.py::test_client_basic
Testing started at 8:41 AM ...
Launching pytest with arguments src/client_test.py::test_client_basic --no-header --no-summary -q in /nfs4/llm/h2ogpt

============================= test session starts ==============================
collecting ...
src/client_test.py:None (src/client_test.py)
ImportError while importing test module '/nfs4/llm/h2ogpt/src/client_test.py'.
Hint: make sure your test modules/packages have valid Python names.
Traceback:
h2ogpt/lib/python3.10/site-packages/_pytest/python.py:618: in _importtestmodule
    mod = import_path(self.path, mode=importmode, root=self.config.rootpath)
h2ogpt/lib/python3.10/site-packages/_pytest/pathlib.py:533: in import_path
    importlib.import_module(module_name)
/usr/lib/python3.10/importlib/__init__.py:126: in import_module
    return _bootstrap._gcd_import(name[level:], package, level)
<frozen importlib._bootstrap>:1050: in _gcd_import
    ???
<frozen importlib._bootstrap>:1027: in _find_and_load
    ???
<frozen importlib._bootstrap>:1006: in _find_and_load_unlocked
    ???
<frozen importlib._bootstrap>:688: in _load_unlocked
    ???
h2ogpt/lib/python3.10/site-packages/_pytest/assertion/rewrite.py:168: in exec_module
    exec(co, module.__dict__)
src/client_test.py:51: in <module>
    from enums import DocumentSubset, LangChainAction
E   ModuleNotFoundError: No module named 'enums'


collected 0 items / 1 error

=============================== 1 error in 0.14s ===============================
ERROR: not found: /nfs4/llm/h2ogpt/src/client_test.py::test_client_basic
(no name '/nfs4/llm/h2ogpt/src/client_test.py::test_client_basic' in any of [<Module client_test.py>])


Process finished with exit code 4

What happened?
""", prompt_type=prompt_type, max_new_tokens=100, version=version)
|
215 |
+
|
216 |
+
|
217 |
+
def run_client_nochat(prompt, prompt_type, max_new_tokens, version=None, h2ogpt_key=None, visible_models=None):
    """Single non-chat query via the positional /submit_nochat endpoint.

    :return: (result dict with plain-text response, client)
    """
    client = get_client(serialize=True)
    kwargs, args = get_args(prompt, prompt_type, chat=False, max_new_tokens=max_new_tokens, version=version,
                            visible_models=visible_models, h2ogpt_key=h2ogpt_key)

    res = client.predict(*tuple(args), api_name='/submit_nochat')
    print("Raw client result: %s" % res, flush=True)
    res_dict = dict(prompt=kwargs['instruction_nochat'],
                    iinput=kwargs['iinput_nochat'],
                    response=md_to_text(res))
    print(res_dict)
    return res_dict, client
|
232 |
+
|
233 |
+
|
234 |
+
@pytest.mark.skip(reason="For manual use against some server, no server launched")
def test_client_basic_api(prompt_type='human_bot', version=None, h2ogpt_key=None):
    """Manual test: one query through the stable string-dict API endpoint."""
    params = dict(prompt='Who are you?', prompt_type=prompt_type, max_new_tokens=50,
                  version=version, h2ogpt_key=h2ogpt_key)
    return run_client_nochat_api(**params)
|
238 |
+
|
239 |
+
|
240 |
+
def run_client_nochat_api(prompt, prompt_type, max_new_tokens, version=None, h2ogpt_key=None):
    """Non-chat query via /submit_nochat_api (stable string-dict API).

    The server returns a stringified python dict; it is parsed once with
    ast.literal_eval (original parsed it twice).

    :return: (result dict with plain-text response and sources, client)
    """
    kwargs, args = get_args(prompt, prompt_type, chat=False, max_new_tokens=max_new_tokens, version=version,
                            h2ogpt_key=h2ogpt_key)

    api_name = '/submit_nochat_api'  # NOTE: like submit_nochat but stable API for string dict passing
    client = get_client(serialize=True)
    res = client.predict(
        str(dict(kwargs)),
        api_name=api_name,
    )
    print("Raw client result: %s" % res, flush=True)
    # parse the returned dict-string once instead of twice
    res_parsed = ast.literal_eval(res)
    res_dict = dict(prompt=kwargs['instruction_nochat'], iinput=kwargs['iinput_nochat'],
                    response=md_to_text(res_parsed['response']),
                    sources=res_parsed['sources'])
    print(res_dict)
    return res_dict, client
|
256 |
+
|
257 |
+
|
258 |
+
@pytest.mark.skip(reason="For manual use against some server, no server launched")
def test_client_basic_api_lean(prompt_type='human_bot', version=None, h2ogpt_key=None):
    """Manual test: minimal-payload query through the string-dict API."""
    params = dict(prompt='Who are you?', prompt_type=prompt_type, max_new_tokens=50,
                  version=version, h2ogpt_key=h2ogpt_key)
    return run_client_nochat_api_lean(**params)
|
262 |
+
|
263 |
+
|
264 |
+
def run_client_nochat_api_lean(prompt, prompt_type, max_new_tokens, version=None, h2ogpt_key=None):
    """Minimal-payload query: send only instruction_nochat (+key), server defaults the rest.

    NOTE: prompt_type/max_new_tokens/version are accepted for signature parity with the
    other runners but are intentionally not sent (matches original behavior).

    :return: (result dict with plain-text response and sources, client)
    """
    kwargs = dict(instruction_nochat=prompt, h2ogpt_key=h2ogpt_key)

    api_name = '/submit_nochat_api'  # NOTE: like submit_nochat but stable API for string dict passing
    client = get_client(serialize=True)
    res = client.predict(
        str(dict(kwargs)),
        api_name=api_name,
    )
    print("Raw client result: %s" % res, flush=True)
    # parse the returned dict-string once instead of twice
    res_parsed = ast.literal_eval(res)
    res_dict = dict(prompt=kwargs['instruction_nochat'],
                    response=md_to_text(res_parsed['response']),
                    sources=res_parsed['sources'],
                    h2ogpt_key=h2ogpt_key)
    print(res_dict)
    return res_dict, client
|
280 |
+
|
281 |
+
|
282 |
+
@pytest.mark.skip(reason="For manual use against some server, no server launched")
def test_client_basic_api_lean_morestuff(prompt_type='human_bot', version=None, h2ogpt_key=None):
    """Manual test: string-dict API query with an explicit full parameter payload."""
    params = dict(prompt='Who are you?', prompt_type=prompt_type, max_new_tokens=50,
                  version=version, h2ogpt_key=h2ogpt_key)
    return run_client_nochat_api_lean_morestuff(**params)
|
286 |
+
|
287 |
+
|
288 |
+
def run_client_nochat_api_lean_morestuff(prompt, prompt_type='human_bot', max_new_tokens=512, version=None,
                                         h2ogpt_key=None):
    """String-dict API query with an explicit (hand-built) parameter payload.

    Fixes: the max_new_tokens parameter is now actually sent (was hard-coded to 1024,
    silently ignoring the argument), and the response dict-string is parsed once.

    :return: (result dict with plain-text response and sources, client)
    """
    kwargs = dict(
        instruction='',
        iinput='',
        context='',
        stream_output=False,
        prompt_type=prompt_type,
        temperature=0.1,
        top_p=0.75,
        top_k=40,
        num_beams=1,
        max_new_tokens=max_new_tokens,
        min_new_tokens=0,
        early_stopping=False,
        max_time=20,
        repetition_penalty=1.0,
        num_return_sequences=1,
        do_sample=True,
        chat=False,
        instruction_nochat=prompt,
        iinput_nochat='',
        langchain_mode='Disabled',
        add_chat_history_to_context=True,
        langchain_action=LangChainAction.QUERY.value,
        langchain_agents=[],
        top_k_docs=4,
        document_subset=DocumentSubset.Relevant.name,
        document_choice=[],
        h2ogpt_key=h2ogpt_key,
        add_search_to_context=False,
    )

    api_name = '/submit_nochat_api'  # NOTE: like submit_nochat but stable API for string dict passing
    client = get_client(serialize=True)
    res = client.predict(
        str(dict(kwargs)),
        api_name=api_name,
    )
    print("Raw client result: %s" % res, flush=True)
    # parse the returned dict-string once instead of twice
    res_parsed = ast.literal_eval(res)
    res_dict = dict(prompt=kwargs['instruction_nochat'],
                    response=md_to_text(res_parsed['response']),
                    sources=res_parsed['sources'],
                    h2ogpt_key=h2ogpt_key)
    print(res_dict)
    return res_dict, client
|
334 |
+
|
335 |
+
|
336 |
+
@pytest.mark.skip(reason="For manual use against some server, no server launched")
def test_client_chat(prompt_type='human_bot', version=None, h2ogpt_key=None):
    """Manual test: non-streaming chat round trip."""
    params = dict(prompt='Who are you?', prompt_type=prompt_type, stream_output=False,
                  max_new_tokens=50,
                  langchain_mode='Disabled',
                  langchain_action=LangChainAction.QUERY.value,
                  langchain_agents=[],
                  version=version,
                  h2ogpt_key=h2ogpt_key)
    return run_client_chat(**params)
|
344 |
+
|
345 |
+
|
346 |
+
@pytest.mark.skip(reason="For manual use against some server, no server launched")
def test_client_chat_stream(prompt_type='human_bot', version=None, h2ogpt_key=None):
    """Manual test: streaming chat round trip with a long generation."""
    params = dict(prompt="Tell a very long kid's story about birds.", prompt_type=prompt_type,
                  stream_output=True, max_new_tokens=512,
                  langchain_mode='Disabled',
                  langchain_action=LangChainAction.QUERY.value,
                  langchain_agents=[],
                  version=version,
                  h2ogpt_key=h2ogpt_key)
    return run_client_chat(**params)
|
355 |
+
|
356 |
+
|
357 |
+
def run_client_chat(prompt='',
                    stream_output=None,
                    max_new_tokens=128,
                    langchain_mode='Disabled',
                    langchain_action=LangChainAction.QUERY.value,
                    langchain_agents=None,
                    prompt_type=None, prompt_dict=None,
                    version=None,
                    h2ogpt_key=None):
    """Chat-mode round trip via the /instruction + /instruction_bot endpoints.

    Uses serialize=False as required for the chatbot-state endpoints.

    :return: (result dict, client)
    """
    # fix: avoid shared mutable default argument (was langchain_agents=[])
    if langchain_agents is None:
        langchain_agents = []
    client = get_client(serialize=False)

    kwargs, args = get_args(prompt, prompt_type, chat=True, stream_output=stream_output,
                            max_new_tokens=max_new_tokens,
                            langchain_mode=langchain_mode,
                            langchain_action=langchain_action,
                            langchain_agents=langchain_agents,
                            prompt_dict=prompt_dict,
                            version=version,
                            h2ogpt_key=h2ogpt_key)
    return run_client(client, prompt, args, kwargs)
|
377 |
+
|
378 |
+
|
379 |
+
def run_client(client, prompt, args, kwargs, do_md_to_text=True, verbose=False):
    """Drive a chat-mode generation, streaming or not.

    Submits the instruction, appends the echoed instruction to the chatbot state,
    then either blocks on /instruction_bot or polls a streaming job, printing
    partial responses. The chatbot state output format appears to be
    [[prompt, answer], ...] nested in the endpoint tuple — see index comments below.

    :param args: positional endpoint args from get_args(chat=True); last item is chatbot state
    :param do_md_to_text: render markdown responses to plain text
    :return: (kwargs dict augmented with 'prompt' and 'response', client)
    """
    assert kwargs['chat'], "Chat mode only"
    res = client.predict(*tuple(args), api_name='/instruction')
    # append the echoed instruction/chatbot entry so /instruction_bot sees full state
    args[-1] += [res[-1]]

    res_dict = kwargs
    res_dict['prompt'] = prompt
    if not kwargs['stream_output']:
        # blocking call; final answer is last chatbot pair's bot side
        res = client.predict(*tuple(args), api_name='/instruction_bot')
        res_dict['response'] = res[0][-1][1]
        print(md_to_text(res_dict['response'], do_md_to_text=do_md_to_text))
        return res_dict, client
    else:
        # streaming: poll the job's accumulated outputs until done
        job = client.submit(*tuple(args), api_name='/instruction_bot')
        res1 = ''
        while not job.done():
            outputs_list = job.communicator.job.outputs
            if outputs_list:
                res = job.communicator.job.outputs[-1]
                res1 = res[0][-1][-1]
                res1 = md_to_text(res1, do_md_to_text=do_md_to_text)
                print(res1)
            time.sleep(0.1)
        full_outputs = job.outputs()
        if verbose:
            print('job.outputs: %s' % str(full_outputs))
        # ensure get ending to avoid race
        # -1 means last response if streaming
        # 0 means get text_output, ignore exception_text
        # 0 means get list within text_output that looks like [[prompt], [answer]]
        # 1 means get bot answer, so will have last bot answer
        res_dict['response'] = md_to_text(full_outputs[-1][0][0][1], do_md_to_text=do_md_to_text)
        return res_dict, client
|
412 |
+
|
413 |
+
|
414 |
+
@pytest.mark.skip(reason="For manual use against some server, no server launched")
def test_client_nochat_stream(prompt_type='human_bot', version=None, h2ogpt_key=None):
    """Manual test: streaming non-chat generation via the string-dict API."""
    params = dict(prompt="Tell a very long kid's story about birds.", prompt_type=prompt_type,
                  stream_output=True, max_new_tokens=512,
                  langchain_mode='Disabled',
                  langchain_action=LangChainAction.QUERY.value,
                  langchain_agents=[],
                  version=version,
                  h2ogpt_key=h2ogpt_key)
    return run_client_nochat_gen(**params)
|
423 |
+
|
424 |
+
|
425 |
+
def run_client_nochat_gen(prompt, prompt_type, stream_output, max_new_tokens,
                          langchain_mode, langchain_action, langchain_agents, version=None,
                          h2ogpt_key=None):
    """Non-chat generation via /submit_nochat_api, optionally streaming.

    :return: (result dict, client)
    """
    kwargs, args = get_args(prompt, prompt_type, chat=False, stream_output=stream_output,
                            max_new_tokens=max_new_tokens, langchain_mode=langchain_mode,
                            langchain_action=langchain_action, langchain_agents=langchain_agents,
                            version=version, h2ogpt_key=h2ogpt_key)
    # serialize=False required for job/streaming usage downstream
    gr_client = get_client(serialize=False)
    return run_client_gen(gr_client, prompt, args, kwargs)
|
435 |
+
|
436 |
+
|
437 |
+
def run_client_gen(client, prompt, args, kwargs, do_md_to_text=True, verbose=False):
    """Run a non-chat generation through /submit_nochat_api, streaming or not.

    The endpoint takes a stringified python dict and returns one; streaming polls
    the job's accumulated string outputs, then takes the last one as final
    (guards against racing the job-completion flag).

    :param args: unused here; kept for signature parity with run_client
    :param verbose: currently unused in this implementation
    :return: (parsed response dict, client)
    """
    res_dict = kwargs
    res_dict['prompt'] = prompt
    if not kwargs['stream_output']:
        # blocking one-shot call
        res = client.predict(str(dict(kwargs)), api_name='/submit_nochat_api')
        res_dict.update(ast.literal_eval(res))
        print(md_to_text(res_dict['response'], do_md_to_text=do_md_to_text))
        return res_dict, client
    else:
        # streaming: poll partial dict-string outputs until the job completes
        job = client.submit(str(dict(kwargs)), api_name='/submit_nochat_api')
        while not job.done():
            outputs_list = job.communicator.job.outputs
            if outputs_list:
                res = job.communicator.job.outputs[-1]
                res_dict = ast.literal_eval(res)
                print('Stream: %s' % res_dict['response'])
            time.sleep(0.1)
        # read final accumulated outputs after completion to avoid missing the tail
        res_list = job.outputs()
        assert len(res_list) > 0, "No response, check server"
        res = res_list[-1]
        res_dict = ast.literal_eval(res)
        print('Final: %s' % res_dict['response'])
        return res_dict, client
|
460 |
+
|
461 |
+
|
462 |
+
def md_to_text(md, do_md_to_text=True):
    """Render a markdown string to plain text (markdown -> HTML -> bs4 get_text).

    :param md: markdown source; must not be None when do_md_to_text is True
    :param do_md_to_text: when False, return the input unchanged
    """
    if not do_md_to_text:
        return md
    assert md is not None, "Markdown is None"
    rendered_html = markdown.markdown(md)
    return BeautifulSoup(rendered_html, features='html.parser').get_text()
|
469 |
+
|
470 |
+
|
471 |
+
def run_client_many(prompt_type='human_bot', version=None, h2ogpt_key=None):
    """Run every manual client test once, sequentially.

    :return: 7-tuple of result dicts, one per test, in the original order
    """
    shared = dict(prompt_type=prompt_type, version=version, h2ogpt_key=h2ogpt_key)
    test_fns = (test_client_chat,
                test_client_chat_stream,
                test_client_nochat_stream,
                test_client_basic,
                test_client_basic_api,
                test_client_basic_api_lean,
                test_client_basic_api_lean_morestuff)
    # each test returns (result_dict, client); keep only the result dicts
    return tuple(fn(**shared)[0] for fn in test_fns)
|
481 |
+
|
482 |
+
|
483 |
+
if __name__ == '__main__':
    # Manual smoke-test entry point: exercises every client path against a live server.
    run_client_many()
|
src/create_data.py
ADDED
@@ -0,0 +1,1847 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Dataset creation tools.
|
3 |
+
|
4 |
+
Keep top-level imports clean of non-trivial imports for specific tools,
|
5 |
+
because this file is imported for various purposes
|
6 |
+
"""
|
7 |
+
|
8 |
+
import ast
|
9 |
+
import concurrent.futures
|
10 |
+
import contextlib
|
11 |
+
import hashlib
|
12 |
+
import json
|
13 |
+
import os
|
14 |
+
import shutil
|
15 |
+
import signal
|
16 |
+
import sys
|
17 |
+
import traceback
|
18 |
+
from concurrent.futures import ProcessPoolExecutor
|
19 |
+
|
20 |
+
import psutil
|
21 |
+
import pytest
|
22 |
+
import pandas as pd
|
23 |
+
import numpy as np
|
24 |
+
from tqdm import tqdm
|
25 |
+
|
26 |
+
from utils import flatten_list, remove
|
27 |
+
|
28 |
+
|
29 |
+
def parse_rst_file(filepath):
    """Parse an RST FAQ file into a {question: answer} dict.

    A question is any Text node (inside a section) ending in '?'; subsequent
    Text nodes are concatenated as its answer until the next question.

    :param filepath: path to the .rst file
    :return: dict mapping question text to answer text (last answer wins on duplicates)
    """
    with open(filepath, 'r') as f:
        input_data = f.read()
    settings_overrides = {'initial_header_level': 2}
    # local import: docutils is only needed for this tool
    from docutils import core
    document = core.publish_doctree(
        source=input_data,
        source_path=filepath,
        settings_overrides=settings_overrides,
    )
    qa_pairs = []
    current_section = None
    current_question = ""
    current_answer = ""
    for node in document.traverse():
        if node.__class__.__name__ == 'section':
            current_section = ""
        elif current_section is not None:
            if node.__class__.__name__ == 'Text':
                # fix: endswith avoids IndexError on empty text nodes (original used [-1])
                if node.astext().endswith("?"):
                    if current_question:
                        qa_pairs.append((current_question, current_answer))
                    current_question = node.astext()
                    current_answer = ""
                else:
                    current_answer += node.astext()
    if current_answer:
        qa_pairs.append((current_question, current_answer))
    # dict() keeps the last answer per duplicate question, same as the comprehension it replaces
    return dict(qa_pairs)
|
58 |
+
|
59 |
+
|
60 |
+
def test_scrape_dai_docs():
    """Scrape the local DAI FAQ rst into dai_faq.json instruction/output pairs."""
    faq_path = os.path.join(os.path.expanduser('~'), 'h2oai/docs/faq.rst')
    qa_pairs = parse_rst_file(faq_path)
    prompt_type = 'human_bot'
    from prompter import prompt_types
    assert prompt_type in prompt_types
    records = [{"instruction": question, "output": answer, 'prompt_type': prompt_type}
               for question, answer in qa_pairs.items()]
    output_file = "dai_faq.json"
    with open(output_file, "wt") as f:
        f.write(json.dumps(records, indent=2))
|
71 |
+
|
72 |
+
|
73 |
+
def test_scrape_dai_docs_all():
    """
    pytest create_data.py::test_scrape_dai_docs_all

    Scrape all DAI rst docs into plain-prompt train/valid json files, chunked
    into sentence groups of roughly 100/200/500 characters (all lengths pooled
    into the same output files).
    """
    import glob
    import nltk
    nltk.download('punkt')
    dd = {}
    np.random.seed(1234)
    home = os.path.expanduser('~')
    files = list(glob.glob(os.path.join(home, "h2oai/docs/**/*rst")))
    np.random.shuffle(files)
    # 5% holdout split by file
    val_count = int(0.05 * len(files))
    train_files = files[val_count:]
    valid_files = files[:val_count]
    things = [
        ("dai_docs.train.json", train_files),
        ("dai_docs.valid.json", valid_files)
    ]
    for LEN in [100, 200, 500]:
        for output_file, ff in things:
            if output_file not in dd:
                dd[output_file] = []
            for f in ff:
                # fix: 'infile' instead of 'input' — don't shadow the builtin
                with open(f) as infile:
                    blob = infile.read()
                # strip rst adornment characters before sentence splitting
                blob = blob.replace("~~", "")
                blob = blob.replace("==", "")
                blob = blob.replace("''", "")
                blob = blob.replace("--", "")
                blob = blob.replace("**", "")
                dd[output_file].extend(get_sentences(blob, length=LEN))
    for output_file, _ in things:
        save_thing = [{"output": k.strip(), 'prompt_type': 'plain'} for k in dd[output_file]]
        with open(output_file, "wt") as f:
            f.write(json.dumps(save_thing, indent=2))
|
109 |
+
|
110 |
+
|
111 |
+
def get_sentences(blob, length):
    """
    break-up input text into sentences and then output list of sentences of about length in size

    Fixes over the original: a sentence longer than `length` is kept as its own
    chunk instead of being silently dropped, and the final accumulated chunk is
    always emitted instead of being lost when earlier chunks exist.

    :param blob: input text
    :param length: approximate target chunk size in characters
    :return: non-empty list of sentence chunks (may be [''] for empty input)
    """
    import nltk
    nltk.download('punkt')
    from nltk.tokenize import sent_tokenize
    sentences = sent_tokenize(blob)
    my_sentences = []
    my_string = ""
    for sentence in sentences:
        if len(my_string) + len(sentence) <= length:
            # still fits: accumulate into the current chunk
            my_string = my_string + " " + sentence if my_string else sentence
        else:
            # flush the current chunk (if any) and start a new one with this
            # sentence — kept even when it alone exceeds `length`
            if my_string:
                my_sentences.append(my_string)
            my_string = sentence
    if my_string:
        my_sentences.append(my_string)
    return my_sentences or [my_string]
|
134 |
+
|
135 |
+
|
136 |
+
def setup_dai_docs(path=None, dst="working_dir_docs", from_hf=False):
    """
    Only supported if have access to source code or HF token for HF spaces and from_hf=True

    Copies all doc files into a flat working directory (pandoc can't find include
    files otherwise), with a scorers/ subdir holding .frag includes for relative paths.

    :param path: glob pattern for docs; inferred from ~/h2oai* when None
    :param dst: flat destination directory, recreated from scratch
    :param from_hf: download and unpack the docs zip from the HF dataset hub instead
    :return: dst
    """

    home = os.path.expanduser('~')

    if from_hf:
        # assumes
        from huggingface_hub import hf_hub_download
        # True for case when locally already logged in with correct token, so don't have to set key
        token = os.getenv('HUGGING_FACE_HUB_TOKEN', True)
        path_to_zip_file = hf_hub_download('h2oai/dai_docs', 'dai_docs.zip', token=token, repo_type='dataset')
        path = 'h2oai'
        import zipfile
        with zipfile.ZipFile(path_to_zip_file, 'r') as zip_ref:
            zip_ref.extractall(path)
        path = os.path.join(path, 'docs/**/*')

    if path is None:
        if os.path.isdir(os.path.join(home, 'h2oai')):
            path = os.path.join(home, "h2oai/docs/**/*")
        else:
            assert os.path.isdir(os.path.join(home, 'h2oai.superclean')), '%s does not exist' % path
            path = os.path.join(home, "h2oai.superclean/docs/**/*")
    import glob
    files = list(glob.glob(path, recursive=True))

    # pandoc can't find include files

    remove(dst)
    os.makedirs(dst)

    # copy full tree, for absolute paths in rst
    for fil in files:
        if os.path.isfile(fil):
            shutil.copy(fil, dst)

    # hack for relative path
    scorers_dir = os.path.join(dst, 'scorers')
    # fix: original called undefined name `makedirs` (NameError); use os.makedirs
    os.makedirs(scorers_dir, exist_ok=True)
    for fil in glob.glob(os.path.join(dst, '*.frag')):
        shutil.copy(fil, scorers_dir)

    return dst
|
185 |
+
|
186 |
+
|
187 |
+
def rst_to_outputs(files, min_len=30, max_len=2048 // 2 - 30):
    """
    Convert .rst doc files to plain-text chunks for training rows.

    :param files: rst file paths; pandoc is run from each file's directory (cwd is NOT restored)
    :param min_len: chunks shorter than this after strip are dropped
    :param max_len: target maximum chunk length
    :return: list of [text, filename] pairs
    """
    # account for sequence length (context window) including prompt and input and output

    # os.system('pandoc -f rst -t plain ./expert_settings/nlp_settings.rst')
    import pypandoc
    basedir = os.path.abspath(os.getcwd())

    outputs = []
    for fil in files:
        # chdir so pandoc resolves relative includes; caller must restore cwd afterwards
        os.chdir(basedir)
        os.chdir(os.path.dirname(fil))
        fil = os.path.basename(fil)
        print("Processing %s" % fil, flush=True)
        # out_format can be one of: asciidoc, asciidoctor, beamer, biblatex, bibtex, commonmark, commonmark_x,
        # context, csljson, docbook, docbook4, docbook5, docx, dokuwiki,
        # dzslides, epub, epub2, epub3, fb2, gfm, haddock, html, html4, html5, icml,
        # ipynb, jats, jats_archiving, jats_articleauthoring, jats_publishing, jira,
        # json, latex, man,
        # markdown, markdown_github, markdown_mmd, markdown_phpextra, markdown_strict,
        # mediawiki, ms, muse, native, odt, opendocument, opml, org, pdf, plain, pptx,
        # revealjs, rst, rtf, s5, slideous, slidy, tei, texinfo, textile, xwiki, zimwiki
        out_format = 'plain'
        # avoid extra new lines injected into text
        # NOTE(review): the second element looks like broken quoting of '--resource-path=%s' % dst
        # ('dst' is not even in scope here); pandoc receives it as a literal string -- confirm intent
        extra_args = ['--wrap=preserve', '--resource path="%s" % dst']

        plain_list = []
        try:
            # valid for expert settings
            input_rst = pypandoc.convert_file(fil, 'rst')
            # split on literal-block markers and convert each piece separately
            input_list = input_rst.split('\n``')
            for input_subrst in input_list:
                input_plain = pypandoc.convert_text(input_subrst, format='rst', to='plain')
                plain_list.append([input_plain, fil])
        except Exception as e:
            print("file exception: %s %s" % (fil, str(e)), flush=True)

        if not plain_list:
            # if failed to process as pieces of rst, then convert the whole file and chunk by sentences
            output = pypandoc.convert_file(fil, out_format, extra_args=extra_args, format='rst')
            outputs1 = get_sentences(output, length=max_len)
            for oi, output in enumerate(outputs1):
                output = output.replace('\n\n', '\n')
                plain_list.append([output, fil])
        outputs.extend(plain_list)

    # report:
    # [print(len(x)) for x in outputs]

    # deal with blocks longer than context size (sequence length) of 2048
    new_outputs = []
    num_truncated = 0
    num_orig = len(outputs)
    for output, fil in outputs:
        if len(output) < max_len:
            new_outputs.append([output, fil])
            continue
        outputs1 = get_sentences(output, length=max_len)
        for oi, output1 in enumerate(outputs1):
            output1 = output1.replace('\n\n', '\n')
            new_outputs.append([output1, fil])
        num_truncated += 1
    print('num_orig: %s num_truncated: %s' % (num_orig, num_truncated), flush=True)

    # drop tiny fragments
    new_outputs = [[k.strip(), fil] for k, fil in new_outputs if len(k.strip()) > min_len]

    return new_outputs
|
253 |
+
|
254 |
+
|
255 |
+
def test_scrape_dai_docs_all_pandoc():
    """
    pytest -s -v create_data.py::test_scrape_dai_docs_all_pandoc

    Builds dai_docs.train_cleaned.json from the DAI rst docs via pandoc.
    :return:
    """

    dst = setup_dai_docs()

    import glob
    files = list(glob.glob(os.path.join(dst, '*rst'), recursive=True))

    basedir = os.path.abspath(os.getcwd())
    new_outputs = rst_to_outputs(files)
    os.chdir(basedir)

    remove(dst)
    # rst_to_outputs yields [text, filename] pairs; unpack them instead of
    # calling .strip() on the pair itself (which raised AttributeError)
    save_thing = [{"output": k.strip(), 'prompt_type': 'plain'} for k, _ in new_outputs]
    output_file = "dai_docs.train_cleaned.json"
    with open(output_file, "wt") as f:
        f.write(json.dumps(save_thing, indent=2))
|
275 |
+
|
276 |
+
|
277 |
+
def test_config_to_json():
    """
    Convert Driverless AI config.toml metadata into instruct-tuning rows (config.json).

    Needs to run from Driverless AI source directory.
    E.g. (base) jon@gpu:~/h2oai$ pytest -s -v /data/jon/h2ogpt/create_data.py::test_config_to_json ; cp config.json /data/jon/h2ogpt/
    :return:
    """
    try:
        # Arrange
        import json
        from h2oaicore.systemutils import config
        toml_list = []
        for k, v in config.get_meta_dict().items():
            title = (v.title + ": ") if v.title else ''
            comment = v.comment or ''
            if not (title or comment):
                continue
            # several phrasings per setting; entries that evaluate to None
            # (conditional expressions below) are filtered out afterwards
            toml_list.extend(
                [
                    {
                        'prompt_type': 'plain',
                        'instruction': f"<human>: What does {k} do?\n<bot>: {k.replace('_', ' ')} config.toml: {comment or title}\n<human>:".replace(
                            "\n", ""),
                    },
                    {
                        'prompt_type': 'plain',
                        'instruction': f"<human>: Explain {k}.\n<bot>: {k.replace('_', ' ')} config.toml: {comment or title}\n<human>:".replace(
                            "\n", ""),
                    },
                    {
                        'prompt_type': 'plain',
                        'instruction': f"<human>: How can I do this: {title}.\n<bot>: Set the {k.replace('_', ' ')} config.toml\n<human>:".replace(
                            "\n", ""),
                    } if title and comment else None,
                    {
                        'prompt_type': 'human_bot',
                        'instruction': f'Explain the following expert setting for Driverless AI',
                        'input': f"{k}",
                        'output': f"{k.replace('_', ' ')} config.toml: {comment or title}".replace("\n", ""),
                    },
                    {
                        'prompt_type': 'human_bot',
                        'instruction': f'Explain the following expert setting for Driverless AI',
                        'input': f"{k}",
                        'output': f"{k.replace('_', ' ')} config.toml: {title}{comment}".replace("\n", ""),
                    },
                    {
                        'prompt_type': 'human_bot',
                        'instruction': f'Explain the following expert setting for Driverless AI',
                        'input': f"{k.replace('_', ' ')}",
                        'output': f"{k.replace('_', ' ')} config.toml: {title}{comment}".replace("\n", ""),
                    },
                    {
                        'prompt_type': 'human_bot',
                        'instruction': f'Explain the following expert setting for Driverless AI',
                        'input': f"{title}",
                        'output': f"{k.replace('_', ' ')} config.toml: {title}{comment}".replace("\n", ""),
                    },
                    {
                        'prompt_type': 'human_bot',
                        'instruction': f'Provide a short explanation of the expert setting {k}',
                        'output': f"{k.replace('_', ' ')} config.toml: {comment or title}".replace("\n", ""),
                    },
                    {
                        'prompt_type': 'human_bot',
                        'instruction': f'Provide a detailed explanation of the expert setting {k}',
                        'output': f"{k.replace('_', ' ')} config.toml: {title}{comment}".replace("\n", ""),
                    },
                ]
            )
        # drop the conditional None entries
        toml_list = [x for x in toml_list if x]
        with open("config.json", "wt") as f:
            f.write(json.dumps(toml_list, indent=2))
    except Exception as e:
        # best-effort: outside a DAI source tree the h2oaicore import fails; just report
        print("Exception: %s" % str(e), flush=True)
|
351 |
+
|
352 |
+
|
353 |
+
def copy_tree(src, dst, follow_symlink=False):
    """Mirror the directory tree under src into dst, copying files atomically.

    Files that vanish between the walk and the copy are silently skipped.
    """
    makedirs(dst, exist_ok=True)
    for dirpath, _dirnames, filenames in os.walk(src, followlinks=follow_symlink):
        target_dir = dirpath.replace(src, dst)
        makedirs(target_dir, exist_ok=True)
        for name in filenames:
            source_file = os.path.join(dirpath, name)
            target_file = os.path.join(target_dir, name)
            # print("%s -> %s" % (source_file, target_file))
            try:
                atomic_copy(source_file, target_file)
            except FileNotFoundError:
                # best-effort copy: source disappeared mid-walk
                pass
|
366 |
+
|
367 |
+
|
368 |
+
def atomic_move(src, dst):
    """Best-effort move of src to dst; src is removed even when the move failed."""
    try:
        shutil.move(src, dst)
    except (shutil.Error, FileExistsError):
        # destination already exists (or shutil refused); fall through to cleanup
        pass
    remove(src)
|
374 |
+
|
375 |
+
|
376 |
+
def atomic_copy(src=None, dst=None, with_permissions=True):
    """Copy src to dst via a uniquely-named temp file so readers never observe a partial dst.

    No-op when dst already exists.
    """
    if os.path.isfile(dst):
        return
    import uuid
    tmp_dst = dst + str(uuid.uuid4())
    makedirs(os.path.dirname(dst), exist_ok=True)
    copier = shutil.copy if with_permissions else shutil.copyfile
    copier(src, tmp_dst)
    atomic_move(tmp_dst, dst)
    remove(tmp_dst)
|
389 |
+
|
390 |
+
|
391 |
+
def makedirs(path, exist_ok=True):
    """
    Avoid some inefficiency in os.makedirs(): skip the syscall when the
    directory already exists.
    :param path: directory path to create
    :param exist_ok: if False, an already-existing directory is an error
    :return: path (consistently, for both the existing and newly-created case;
        previously None was returned after creation)
    """
    # os.path.isdir() already implies existence; the extra os.path.exists() call was redundant
    if os.path.isdir(path):
        assert exist_ok, "Path already exists"
        return path
    os.makedirs(path, exist_ok=exist_ok)
    return path
|
402 |
+
|
403 |
+
|
404 |
+
## Download from https://huggingface.co/datasets/anon8231489123/ShareGPT_Vicuna_unfiltered/resolve/main/ShareGPT_unfiltered_cleaned_split.json
|
405 |
+
## Turn into simple instruct prompt type. No context/previous conversations.
|
406 |
+
def test_prep_instruct_vicuna():
    """Download ShareGPT (if missing) and flatten each conversation into one
    <human>:/<bot>: transcript row for prompt_type=plain training."""
    from datasets import load_dataset
    filename = 'ShareGPT_unfiltered_cleaned_split.json'
    if not os.path.exists(filename):
        os.system(
            'wget https://huggingface.co/datasets/anon8231489123/ShareGPT_Vicuna_unfiltered/resolve/main/%s' % filename)
    data = load_dataset("json", data_files={"train": filename})["train"]
    training_rows = []
    for i in range(data.num_rows):
        conversations = data[i]['conversations']
        assert isinstance(conversations, list), conversations
        convo = ""
        for j, conv in enumerate(conversations):
            # Get ready for generate.py prompt_type=human_bot
            # But train with prompt_type=plain
            if conv['from'] == 'human':
                FROM = '<human>: '
            elif conv['from'] == 'gpt':
                FROM = '<bot>: '
            else:
                # unknown speaker: skip the turn instead of reusing the previous
                # iteration's FROM (or raising NameError on the first turn)
                continue
            convo += f"{FROM}" + conv['value'] + "\n"
        if convo:
            training_rows.append(dict(input=convo))
    with open(filename + ".generate_human_bot.train_plain.json", "wt") as f:
        f.write(json.dumps(training_rows, indent=2))
|
430 |
+
|
431 |
+
|
432 |
+
# suffix appended to files converted into the simple {"input": ...} plain-training format
POSTFIX = ".generate_human_bot.train_plain.json"

# small OIG subset, per https://bair.berkeley.edu/blog/2023/04/03/koala/
OIG_DATASETS = [
    "unified_chip2.jsonl",
    "unified_grade_school_math_instructions.jsonl",
    "unified_poetry_2_song.jsonl",
    "unified_plot_screenplay_books_dialog.jsonl",
]

# every OIG shard on the hub
# hub issue: https://huggingface.co/datasets/laion/OIG/discussions/4
ALL_OIG_DATASETS = ['unified_abstract_infill.jsonl',
                    'unified_basic.jsonl',
                    'unified_canadian_parliament.jsonl',
                    'unified_chip2.jsonl',
                    'unified_conv_finqa.jsonl',
                    'unified_cuad.jsonl',
                    'unified_essays.jsonl',
                    'unified_flan.jsonl.gz',
                    'unified_grade_school_math_instructions.jsonl',
                    'unified_hc3_human.jsonl',
                    'unified_image_prompts_instructions.jsonl',
                    'unified_joke_explanations.jsonl',
                    'unified_mathqa_flanv2_kojma_cot.jsonl',
                    'unified_merged_code_xp3.jsonl',
                    'unified_multi_news.jsonl',
                    'unified_multi_sum.jsonl',
                    'unified_ni.jsonl.gz',
                    'unified_nq.jsonl',
                    'unified_openai_summarize_tldr.jsonl',
                    'unified_oscar_en_sample_dialog.jsonl',
                    'unified_p3.jsonl.gz',
                    'unified_plot_screenplay_books_dialog.jsonl',
                    'unified_poetry_2_song.jsonl',
                    'unified_poetry_instructions.jsonl',
                    'unified_rallio_safety_and_prosocial.jsonl',
                    'unified_rallio_soda_upgraded_2048.jsonl',
                    'unified_soda_dialog.jsonl',
                    'unified_sqlv1.jsonl',
                    'unified_sqlv2.jsonl',
                    'unified_squad_v2.jsonl',
                    'unified_squad_v2_more_neg.jsonl',
                    'unified_ul2_plus_oscar_en_sample_dialog.jsonl',
                    'unified_unifiedskg_instructions.jsonl',
                    'unified_unnatural_instructions.jsonl',
                    'unified_xp3_sample.jsonl']

# shards judged useful for training (as .parquet); see test_download_useful_data_as_parquet
useful_oig_files = ['unified_rallio_safety_and_prosocial.jsonl.parquet',
                    'unified_chip2.jsonl.parquet',
                    'unified_cuad.jsonl.parquet',
                    'unified_essays.jsonl.parquet',
                    'unified_flan.jsonl.gz.parquet',
                    'unified_grade_school_math_instructions.jsonl.parquet',
                    'unified_hc3_human.jsonl.parquet',
                    'unified_mathqa_flanv2_kojma_cot.jsonl.parquet',
                    'unified_merged_code_xp3.jsonl.parquet',
                    'unified_multi_news.jsonl.parquet',
                    # 'unified_multi_sum.jsonl.parquet'
                    'unified_ni.jsonl.gz.parquet',
                    'unified_openai_summarize_tldr.jsonl.parquet',
                    # 'unified_oscar_en_sample_dialog.jsonl.parquet', # create text containing these N words, not specific
                    'unified_plot_screenplay_books_dialog.jsonl.parquet',
                    'unified_soda_dialog.jsonl.parquet',
                    'unified_unnatural_instructions.jsonl.parquet',
                    ]
|
497 |
+
|
498 |
+
|
499 |
+
@pytest.mark.parametrize("filename", OIG_DATASETS)
def test_get_small_sample_oig_data(filename):
    """Fetch one OIG shard if missing and convert each JSONL line to an {"input": text} row."""
    if not os.path.exists(filename):
        os.system('wget https://huggingface.co/datasets/laion/OIG/resolve/main/%s' % filename)
    import json
    rows = []
    with open(filename, "r") as f:
        for line in f.readlines():
            record = json.loads(line)
            rows.append(dict(input=record["text"]))
    with open(filename + POSTFIX, "w") as f:
        f.write(json.dumps(rows, indent=2))
|
511 |
+
|
512 |
+
|
513 |
+
@pytest.mark.parametrize("filename", ALL_OIG_DATASETS)
def test_download_useful_data_as_parquet(filename):
    """Download an OIG shard (only if whitelisted in useful_oig_files) and convert it to parquet."""
    dest_file = filename + '.parquet'
    if dest_file not in useful_oig_files:
        pytest.skip('file declared not useful')
    if not os.path.exists(filename):
        os.system('wget https://huggingface.co/datasets/laion/OIG/resolve/main/%s' % filename)
    if not os.path.exists(dest_file):
        frame = pd.read_json(path_or_buf=filename, lines=True)
        frame.to_parquet(dest_file, index=False)
|
523 |
+
|
524 |
+
|
525 |
+
def test_merge_shuffle_small_sample_oig_data():
    """Concatenate the converted OIG shards, shuffle them deterministically, and write one merged json."""
    np.random.seed(1234)
    merged = []
    for name in OIG_DATASETS:
        with open(name + POSTFIX, "r") as f:
            merged.extend(json.load(f))
    np.random.shuffle(merged)
    digest = hashlib.sha256(str(OIG_DATASETS).encode()).hexdigest()[:10]
    with open("merged_shuffled_OIG_%s.json" % digest, "w") as f:
        f.write(json.dumps(merged, indent=2))
|
534 |
+
|
535 |
+
|
536 |
+
def test_join_jsons():
|
537 |
+
files = ['config.json'] * 1 + \
|
538 |
+
['dai_docs.train_cleaned.json'] * 2 + \
|
539 |
+
['dai_faq.json'] * 3
|
540 |
+
print(files)
|
541 |
+
lst = []
|
542 |
+
[lst.extend(json.load(open(fil, 'rt'))) for fil in files]
|
543 |
+
print(len(lst))
|
544 |
+
json.dump(lst, open("merged.json", "wt"), indent=2)
|
545 |
+
|
546 |
+
|
547 |
+
@pytest.mark.parametrize("filename", ['Anthropic/hh-rlhf'])
def test_make_rlhf_good_data(filename):
    """Convert the 'chosen' RLHF transcripts to <human>:/<bot>: rows for plain training."""
    from datasets import load_dataset
    chosen = load_dataset(filename)["train"]["chosen"]
    converted = []
    for text in chosen:
        if text[:2] == "\n\n":
            text = text[2:]
        text = text.replace("Human: ", "<human>: ").replace("Assistant: ", "<bot>: ")
        converted.append(dict(input=text))
    with open(filename.replace("/", "_") + POSTFIX, "w") as f:
        f.write(json.dumps(converted, indent=2))
|
560 |
+
|
561 |
+
|
562 |
+
def test_show_prompts():
    """Print the plain-prompt rendering of every data point in a few local json files."""
    files = ['config.json'] * 1 + \
            ['dai_docs.train_cleaned.json'] * 1 + \
            ['dai_faq.json'] * 1
    from prompter import generate_prompt
    for fil in files:
        # use a context manager: the previous list comprehension leaked open file handles
        with open(fil, 'rt') as f:
            data_points = json.load(f)
        for data_point in data_points:
            print(generate_prompt(data_point, 'plain', '', False, False, False)[0])
|
571 |
+
|
572 |
+
|
573 |
+
def test_get_open_datasets():
    """
    Enumerate permissively-licensed, English, task-relevant, not-huge HF datasets,
    then dump each one to parquet via do_one() inside a killable subprocess.
    """
    # HF changed things so don't get raw list of all datasets, so not have to filter, but can't do negative filter
    open_tags = ['license:Apache License 2.0',
                 'license:mit',
                 'license:apache',
                 'license:apache2',
                 'license:apache-2.0',
                 'license:bsd',
                 'license:bsd-2-clause',
                 'license:bsd-3-clause',
                 'license:bsd-3-clause-clear',
                 'license:lgpl-2.1',
                 'license:lgpl-3.0',
                 'license:lgpl-lr',
                 'license:lgpl',
                 'license:openrail++',
                 'license:openrail',
                 'license:bigscience-bloom-rail-1.0',
                 # 'license:agpl-3.0',
                 'license:other',
                 'license:unknown',
                 # 'license:mpl-2.0',  # ok, but would have to include original copyright, license, source, copies in distribution
                 # Attribution required:
                 'license:odc-by',
                 'license:cc-by-4.0',
                 'license:cc-by-3.0',
                 'license:cc-by-2.0',
                 'license:cc-by-2.5',
                 # 'license:cc-by-sa-4.0',  # would require same license
                 'license:odbl',
                 'license:pddl',
                 'license:ms-pl',
                 'license:zlib',
                 ]
    # bad license: cc-by-nc-4.0

    from huggingface_hub import list_datasets
    datasets = flatten_list([[x for x in list_datasets(filter=y)] for y in open_tags])
    datasets += [x for x in list_datasets(author='openai')]
    # check all:
    all_license_tags = set(flatten_list([[y for y in x.tags if 'license' in y] for x in datasets]))
    print(len(all_license_tags))
    # keep datasets with an open tag or with no license tag at all
    open_datasets = [x for x in datasets if any([y in x.tags for y in open_tags]) or 'license:' not in str(x.tags)]
    print('open_datasets', len(open_datasets))
    all_task_tags = set(flatten_list([[y for y in x.tags if 'task' in y] for x in open_datasets]))
    print('all_task_tags', len(all_task_tags))
    excluded_tags = ['image', 'hate', 'tabular', 'table-', 'classification', 'retrieval',
                     'translation', 'identification', 'object', 'mask', 'to-text',
                     'face-detection', 'audio', 'voice', 'reinforcement', 'depth-est',
                     'forecasting', 'parsing', 'visual', 'speech', 'multiple-choice',
                     'slot-filling', 'irds/argsme', '-scoring', 'other', 'graph-ml',
                     'feature-extraction', 'keyword-spotting',
                     'coreference-resolution', 'segmentation',
                     'word-sense-disambiguation',
                     'lemmatization']
    task_tags = [x.replace('task_categories:', '').replace('task_ids:', '')
                 for x in all_task_tags if not any([y in x for y in
                                                    excluded_tags])]
    print('task_tags', len(task_tags))
    # str(x.tags) to catch any pattern match to anything in list
    # NOTE(review): the inner comprehension variable x shadows the outer dataset x
    open_tasked_datasets = [x for x in open_datasets if
                            any([y in str([x for x in x.tags if 'task' in x]) for y in task_tags]) and
                            not any([y in str([x for x in x.tags if 'task' in x]) for y in excluded_tags]) or
                            'task_categories' not in str(x.tags) and 'task_ids' not in str(x.tags)]
    open_tasked_datasets = [x for x in open_tasked_datasets if not x.disabled]
    open_tasked_datasets = [x for x in open_tasked_datasets if not x.gated]
    open_tasked_datasets = [x for x in open_tasked_datasets if not x.private]
    print('open_tasked_datasets', len(open_tasked_datasets))
    # NOTE(review): sizes/languages are computed for manual inspection only; unused below
    sizes = list(set(flatten_list([[(y, x.id) for y in x.tags if 'size' in y] for x in open_tasked_datasets])))
    languages = list(set(flatten_list([[(y, x.id) for y in x.tags if 'language:' in y] for x in open_tasked_datasets])))
    open_english_tasked_datasets = [x for x in open_tasked_datasets if
                                    'language:' not in str(x.tags) or
                                    'language:en' in str(x.tags)]
    small_open_english_tasked_datasets = [x for x in open_english_tasked_datasets if
                                          'n<1K' in str(x.tags) or
                                          '1K<n<10K' in str(x.tags) or
                                          '1K0<n<100K' in str(x.tags) or
                                          '100K<n<1M' in str(x.tags) or
                                          'size_category' not in str(x.tags)
                                          ]
    # 'aeslc' : email_body, subject -> summarization?
    # load_dataset(open_tasked_datasets[0].id).data['train'].to_pandas()
    ids = [x.id for x in small_open_english_tasked_datasets]

    # sanity checks
    # https://bair.berkeley.edu/blog/2023/04/03/koala/
    assert 'alespalla/chatbot_instruction_prompts' in ids
    assert 'laion/OIG' in ids
    assert 'openai/webgpt_comparisons' in ids
    assert 'openai/summarize_from_feedback' in ids
    assert 'Anthropic/hh-rlhf' in ids

    # useful but not allowed for commercial purposes:
    # https://huggingface.co/datasets/squad

    print('open_english_tasked_datasets: ', ids, flush=True)

    exclude_ids = ['allenai/nllb',  # translation only
                   'hf-internal-testing/fixtures_image_utils',  # testing
                   'allenai/c4',  # search-url
                   'agemagician/uniref50',  # unknown
                   'huggingface-course/documentation-images',  # images
                   'smilegate-ai/kor_unsmile',  # korean
                   'MohamedRashad/ChatGPT-prompts',  # ChatGPT/LearnGPT/https://www.emergentmind.com/
                   'humarin/chatgpt-paraphrases',  # Paraphrase using ChatGPT
                   'Jeska/vaccinchat',  # not useful
                   'alespalla/chatbot_instruction_prompts',  # mixes alpaca
                   'allenai/prosocial-dialog',
                   # already excluded, but wrongly in other datasets that say more permissive license
                   'AlekseyKorshuk/persona-chat',  # low quality
                   'bavard/personachat_truecased',  # low quality
                   'adamlin/daily_dialog',  # medium quality conversations
                   'adamlin/FewShotWoz',  # low quality
                   'benjaminbeilharz/better_daily_dialog',  # low quality
                   'benjaminbeilharz/daily_dialog_w_turn_templates',  # low
                   'benjaminbeilharz/empathetic_dialogues_for_lm',  # low
                   'GEM-submissions/GEM__bart_base_schema_guided_dialog__1645547915',  # NA
                   'ia-bentebib/conv_ai_2_fr',  # low fr
                   'ia-bentebib/daily_dialog_fr',  # low fr
                   'ia-bentebib/dialog_re_fr',  # low fr
                   'ia-bentebib/empathetic_dialogues_fr',  # low fr
                   'roskoN/dailydialog',  # low
                   'VadorMazer/skyrimdialogstest',  # low
                   'bigbio/med_qa',  # med specific Q/A
                   'biu-nlp/qa_srl2018',  # low quality Q/A
                   'biu-nlp/qa_discourse',  # low quality Q/A
                   'iarfmoose/qa_evaluator',  # low quality Q/A
                   'jeopardy',  # low quality Q/A -- no reasoning
                   'narrativeqa',  # low quality Q/A
                   'nomic-ai/gpt4all_prompt_generations',  # bad license
                   'nomic-ai/gpt4all_prompt_generations_with_p3',  # bad license
                   'HuggingFaceH4/alpaca',  # bad license
                   'tatsu-lab/alpaca',  # ToS breaking
                   'yahma/alpaca-cleaned',  # ToS breaking
                   'Hello-SimpleAI/HC3',  # bad license
                   'glue',  # no reasoning QA
                   'sahil2801/CodeAlpaca-20k',  # bad license
                   'Short-Answer-Feedback/saf_communication_networks_english',  # long Q, medium A
                   ]
    small_open_english_tasked_datasets = [x for x in small_open_english_tasked_datasets if x.id not in exclude_ids]
    # some ids clearly speech related
    small_open_english_tasked_datasets = [x for x in small_open_english_tasked_datasets if 'speech' not in x.id]
    # HF testing
    small_open_english_tasked_datasets = [x for x in small_open_english_tasked_datasets if
                                          'hf-internal-testing' not in x.id]
    small_open_english_tasked_datasets = [x for x in small_open_english_tasked_datasets if
                                          'chinese' not in x.id]

    # most-downloaded first
    sorted_small_open_english_tasked_datasets = sorted([(x.downloads, x) for x in small_open_english_tasked_datasets],
                                                       key=lambda x: x[0], reverse=True)

    # NOTES:
    # Run like pytest -s -v create_data.py::test_get_open_datasets &> getdata9.log
    # See what needs config passed and add:
    # grep 'load_dataset(' getdata9.log|grep -v data_id|less -S
    # grep "pip install" getdata9.log
    # NOTE: Some datasets have default config, but others are there. Don't know how to access them.

    """
    https://huggingface.co/datasets/wikihow/blob/main/wikihow.py
    https://github.com/mahnazkoupaee/WikiHow-Dataset
    https://ucsb.box.com/s/ap23l8gafpezf4tq3wapr6u8241zz358
    https://ucsb.app.box.com/s/ap23l8gafpezf4tq3wapr6u8241zz358
    """

    """
    # some ambiguous or non-commercial datasets
    https://github.com/PhoebusSi/alpaca-CoT
    """

    timeout = 3 * 60
    # laion/OIG takes longer
    for num_downloads, dataset in sorted_small_open_english_tasked_datasets:
        data_id = dataset.id
        func = do_one
        args = (data_id, num_downloads)
        kwargs = {}
        # run each dataset in its own worker so a hang can be timed out and killed
        with ProcessPoolExecutor(max_workers=1) as executor:
            future = executor.submit(func, *args, **kwargs)
            try:
                future.result(timeout=timeout)
            except concurrent.futures.TimeoutError:
                print("\n\ndata_id %s timeout\n\n" % data_id, flush=True)
            # NOTE(review): escalates SIGINT->SIGTERM->SIGKILL without pauses; os.kill
            # on an already-dead pid raises ProcessLookupError -- confirm intended
            for child in psutil.Process(os.getpid()).children(recursive=True):
                os.kill(child.pid, signal.SIGINT)
                os.kill(child.pid, signal.SIGTERM)
                os.kill(child.pid, signal.SIGKILL)
|
760 |
+
|
761 |
+
|
762 |
+
def do_one(data_id, num_downloads):
    """
    Download one HF dataset (every available config) and dump a split of each to parquet.

    Prints rather than raises on failure so a parent scraping loop can continue.
    :param data_id: HF hub dataset id, e.g. 'laion/OIG'
    :param num_downloads: download count, used only for logging
    """
    from datasets import load_dataset
    out_file = "data_%s.parquet" % str(data_id.replace('/', '_'))
    # skip datasets we already converted to a big (>1GiB) parquet
    if os.path.isfile(out_file) and os.path.getsize(out_file) > 1024 ** 3:
        return
    try:
        print("Loading data_id %s num_downloads: %s" % (data_id, num_downloads), flush=True)
        avail_list = None
        try:
            # probe with a bogus config name; the error message lists valid configs
            data = load_dataset(data_id, 'foobar')
        except Exception as e:
            if 'Available: ' in str(e):
                # parse the list of config names out of the exception text
                avail_list = ast.literal_eval(str(e).split('Available:')[1].strip())
            else:
                avail_list = None
        if avail_list is None:
            # single (default) config
            avail_list = [None]
        print("%s avail_list: %s" % (data_id, avail_list), flush=True)

        for name in avail_list:
            out_file = "data_%s_%s.parquet" % (str(data_id.replace('/', '_')), str(name))
            if os.path.isfile(out_file):
                continue
            data = load_dataset(data_id, name)
            column_names_dict = data.column_names
            column_names = column_names_dict[list(column_names_dict.keys())[0]]
            print("Processing data_id %s num_downloads: %s columns: %s" % (data_id, num_downloads, column_names),
                  flush=True)
            data_dict = data.data
            col_dict = data.num_columns
            first_col = list(col_dict.keys())[0]
            # prefer the 'train' split; otherwise take the first split available
            if 'train' in data_dict:
                df = data['train'].to_pandas()
            else:
                df = data[first_col].to_pandas()
            # csv has issues with escaping chars, even for datasets I know I want
            df.to_parquet(out_file, index=False)
    except Exception as e:
        t, v, tb = sys.exc_info()
        ex = ''.join(traceback.format_exception(t, v, tb))
        print("Exception: %s %s" % (data_id, ex), flush=True)
|
803 |
+
|
804 |
+
|
805 |
+
def test_otherlic():
    """Count attribution-required-license HF datasets, excluding translation sets."""
    from huggingface_hub import list_datasets
    lic = [
        'license:odc-by',
        'license:cc-by-4.0',
        'license:cc-by-3.0',
        'license:cc-by-2.0',
        'license:cc-by-2.5',
        'license:cc-by-sa-4.0',
        'license:odbl',
        'license:pddl',
        'license:ms-pl',
        'license:zlib',
    ]
    datasets = flatten_list(
        [[ds for ds in list_datasets(filter=tag) if 'translation' not in str(ds.tags)]
         for tag in lic])
    print(len(datasets))
|
820 |
+
|
821 |
+
|
822 |
+
# These useful datasets are determined based upon data sample, column types, and uniqueness compared to larger datasets like Pile
# grep columns getdata13.log|grep -v "\['image'\]"|sort|uniq|grep -v tokens|grep -v "'image'"|grep -v embedding|grep dialog
# NOTE: exact duplicate entries removed ('kastan/rlhf-qa-conditional-generation-v2',
# 'squadshifts', 'hotpot_qa' each appeared twice) so downstream loops don't
# process the same dataset more than once.
useful = ['Dahoas/instruct-human-assistant-prompt',
          'Dahoas/first-instruct-human-assistant-prompt',
          'knkarthick/dialogsum',  # summary of conversation
          'McGill-NLP/FaithDial',  # medium quality
          'Zaid/quac_expanded',  # medium quality context + QA
          '0-hero/OIG-small-chip2',  # medium
          'alistvt/coqa-flat',  # QA medium
          'AnonymousSub/MedQuAD_47441_Question_Answer_Pairs',  # QA medium
          'Anthropic/hh-rlhf',  # high quality  # similar to Dahoas/full-hh-rlhf
          'arjunth2001/online_privacy_qna',  # good quality QA
          'Dahoas/instruct_helpful_preferences',  # medium quality instruct
          'Dahoas/rl-prompt-dataset',  # medium chat
          'Dahoas/rm-static',  # medium chat
          'Dahoas/static-hh',  # medium chat  # HuggingFaceH4/self_instruct
          'Dahoas/synthetic-instruct-gptj-pairwise',  # medium chat
          'eli5',  # QA if prompt ELI5
          'gsm8k',  # QA (various)
          'guanaco/guanaco',  # prompt/response
          'kastan/rlhf-qa-comparisons',  # good QA
          'kastan/rlhf-qa-conditional-generation-v2',  # prompt answer
          'OllieStanley/humaneval-mbpp-codegen-qa',  # code QA, but started from words, so better than other code QA
          'OllieStanley/humaneval-mbpp-testgen-qa',  # code QA
          'Graverman/Instruct-to-Code',  # code QA
          'openai/summarize_from_feedback',  # summarize
          'relbert/analogy_questions',  # analogy QA
          'yitingxie/rlhf-reward-datasets',  # prompt, chosen, rejected.
          'yizhongw/self_instruct',  # instruct (super natural & instruct)
          'HuggingFaceH4/asss',  # QA, big A
          'cosmos_qa',  # context QA
          'vishal-burman/c4-faqs',  # QA but not so much reasoning, but alot of text
          'squadshifts',  # QA from context
          'hotpot_qa',  # QA from context
          'adversarial_qa',  # QA from context
          'allenai/soda',  # dialog -> narrative/summary
          'squad_v2',  # context QA
          'dferndz/cSQuAD1',  # context QA
          'dferndz/cSQuAD2',  # context QA
          'din0s/msmarco-nlgen',  # context QA
          'domenicrosati/TruthfulQA',  # common sense truthful QA -- trivia but good trivia
          'HuggingFaceH4/self-instruct-eval',  # instruct QA, medium quality, some language reasoning
          'kastan/EE_QA_for_RLHF',  # context QA
          'KK04/LogicInference_OA',  # instruction logical QA
          'lmqg/qa_squadshifts_synthetic',  # context QA
          'lmqg/qg_squad',  # context QA
          'lmqg/qg_squadshifts',  # context QA
          'lmqg/qg_subjqa',  # context QA
          'pszemraj/HC3-textgen-qa',
          # QA medium, has human responses -- humans tend to provide links instead of trying to answer
          'pythonist/newdata',  # long context, QA, brief A
          'ropes',  # long background, situation, question, A
          'wikitablequestions',  # table -> QA
          'bigscience/p3',  # context QA but short answers
          ]
|
880 |
+
|
881 |
+
# Code-oriented QA datasets worth keeping.
code_useful = ['0n1xus/codexglue',
               'openai_humaneval',
               'koutch/staqc',
               ]

# Candidates of uncertain value; kept for reference, not clearly useful.
maybe_useful = ['AlekseyKorshuk/comedy-scripts',
                'openbookqa',  # hard to parse, low reasoning
                'qed',  # reasonable QA, but low reasoning
                'selqa',  # candidate answers
                'HuggingFaceH4/instruction-pilot-outputs-filtered',
                'GBaker/MedQA-USMLE-4-options',  # medical QA with long questions
                'npc-engine/light-batch-summarize-dialogue',  # dialog summarize, kinda low specific quality
                ]

# Summarization-oriented datasets.
summary_useful = ['austin/rheum_abstracts',
                  'CarperAI/openai_summarize_comparisons',  # summarize chosen/rejected
                  'CarperAI/openai_summarize_tldr',  # summarize QA
                  'ccdv/cnn_dailymail',  # summarize news
                  'ccdv/govreport-summarization',  # summarize high quality
                  'ccdv/pubmed-summarization',  # summarize high quality
                  'duorc',  # plot -> QA
                  'farleyknight/big_patent_5_percent',  # desc -> abstract
                  'multi_news',  # summary
                  'opinosis',
                  'SophieTr/reddit_clean',
                  'allenai/mup',  # long text -> summary
                  'allenai/multi_lexsum',  # long text -> summary
                  'big_patent',
                  'allenai/wcep_dense_max',
                  'awinml/costco_long_practice',
                  'GEM/xsum',
                  'ratishsp/newshead',
                  'RussianNLP/wikiomnia',  # russian
                  'stacked-summaries/stacked-xsum-1024',
                  ]

# Math word-problem datasets.
math_useful = [
    'competition_math'
]

# Deliberately skipped datasets.
skipped = ['c4',  # maybe useful, used for flan, but skipped due to size
           ]

"""
To get training data from oig:
pytest test_oig test_grade_final test_finalize_to_json
"""

# Turn markers used throughout the human_bot prompt format below.
human = '<human>:'
bot = '<bot>:'
|
931 |
+
|
932 |
+
|
933 |
+
def test_assemble_and_detox():
    """Assemble all OIG parquet files into one cleaned human/bot parquet.

    Steps per source file:
      1. chop each text into <human>-delimited interactions capped in length,
      2. drop duplicates,
      3. drop rows flagged by alt-profanity-check (probability >= 0.25).
    Finally shuffles and writes 'h2oGPT.cleaned.human_bot.shorter.parquet'.
    Depends on module globals: useful_oig_files, get_sentences, pd, tqdm.
    """
    import re
    from profanity_check import predict_prob
    df_list = []
    for data in useful_oig_files:
        print("Processing %s" % data, flush=True)
        df = pd.read_parquet(data)
        df = df.reset_index(drop=True)
        # chop up into human/bot interactions of no more than 10kB per row
        text_list = df[['text']].values.ravel().tolist()
        new_text = []
        max_len = 2048  # uber cutoff
        MAX_LEN = 2048 // 2 - 30  # max len per question/answer
        for text in tqdm(text_list):
            human_starts = [m.start() for m in re.finditer('<human>: ', text)]
            if len(human_starts) == 1:
                human_starts = [0, len(text)]  # always go into for loop below
            blurb = ''
            for i in range(len(human_starts) - 1):
                # cap each single interaction at max_len characters
                interaction = text[human_starts[i]: human_starts[i + 1]][:max_len]
                blurb += interaction
                if len(blurb) >= MAX_LEN:
                    # trim at a sentence boundary, emit, and start a new blurb
                    blurb = get_sentences(blurb, length=MAX_LEN)[0]
                    new_text.append(blurb + "\n<human>:")
                    blurb = ''
            if blurb:
                # flush any trailing partial blurb
                blurb = get_sentences(blurb, length=MAX_LEN)[0]
                new_text.append(blurb + "\n<human>:")

        if len(new_text) > len(text_list):
            print("Added %d new rows (before: %d)" % (len(new_text) - df.shape[0], df.shape[0]))
        df = pd.DataFrame({"text": new_text, "source": [data] * len(new_text)})
        df = df.drop_duplicates(keep='first')
        print(df['text'].apply(lambda x: len(x)).describe())
        assert df['text'].apply(lambda x: len(x)).max() <= 2 * max_len

        # faster than better_profanity, do early
        df['profanity'] = predict_prob(df['text'])
        before_rows = df.shape[0]
        df = df[df['profanity'] < 0.25]  # drop any low quality stuff
        after_rows = df.shape[0]
        print("Dropped %d rows out of %d due to alt-profanity-check" % (before_rows - after_rows, before_rows))
        df_list.append(df)
        print("Done processing %s -> %s rows" % (data, df.shape[0]), flush=True)
        print("So far have %d rows" % sum([len(x) for x in df_list]))
    df_final = pd.concat(df_list)
    # fixed seed so the shuffled output is reproducible
    df_final = df_final.sample(frac=1, random_state=1234).reset_index(drop=True)
    df_final.to_parquet('h2oGPT.cleaned.human_bot.shorter.parquet', index=False)
|
981 |
+
|
982 |
+
|
983 |
+
def test_basic_cleaning():
    """Filter OIG parquet files by per-turn word counts and profanity.

    Keeps rows whose average words-per-turn and average bot words fall within
    [~median-derived minimum, 2048], and drops rows flagged by
    alt-profanity-check.  Writes 'h2oGPT.cleaned.human_bot.parquet'.
    Depends on module globals: useful_oig_files, human, bot, pd, np.
    """
    # from better_profanity import profanity
    # https://pypi.org/project/alt-profanity-check/
    from profanity_check import predict
    df_list = []
    for data in useful_oig_files:
        # for data in useful_oig_files[:5]:
        # for data in ['unified_openai_summarize_tldr.jsonl.parquet']:
        print("Processing %s" % data, flush=True)
        df = pd.read_parquet(data)
        df = df.reset_index(drop=True)
        # NOTE: Not correct if multiple human-bot interactions, but those dialogs even more desired
        # avg_chars = len(df['text'][0])/(df['text'][0].count(human)+df['text'][0].count(bot))
        # rough words-per-turn proxy: spaces per (human+bot) marker, halved
        df['avg_words'] = df['text'].apply(lambda x: x.count(' ') / (x.count(human) + x.count(bot)) / 2.0)
        # average words in bot turns only (first bot turn onward)
        df['avg_bot_words'] = df['text'].apply(lambda x: x.split(bot)[1].count(' ') / x.count(bot))
        # df['bad_words'] = df['text'].apply(lambda x: profanity.contains_profanity(x))
        # low_quality_patterns = ['Write the rest of this wikipedia article']
        res = predict(df['text'])
        df['bad_words'] = res
        df = df.reset_index(drop=True)
        df = df[df['bad_words'] == 0]
        df = df[['text', 'avg_words', 'avg_bot_words']]
        df = df.drop_duplicates(keep='first')
        print(df[df['avg_words'] == df['avg_words'].max()]['text'].values)
        median_words = np.median(df['avg_words'])
        min_words_per_entity = max(30, 0.8 * median_words)
        max_words_per_entity = 2048  # too hard to learn from for now
        df = df[df['avg_words'] > min_words_per_entity]
        df = df[df['avg_words'] < max_words_per_entity]

        min_words_per_entity = max(20, 0.5 * median_words)  # bot should say stuff for now
        max_words_per_entity = 2048  # too hard to learn from for now
        df = df[df['avg_bot_words'] > min_words_per_entity]
        df = df[df['avg_bot_words'] < max_words_per_entity]

        df_list.append(df)
        print("Done processing %s -> %s rows" % (data, df.shape[0]), flush=True)
    df_final = pd.concat(df_list)
    df_final.to_parquet('h2oGPT.cleaned.human_bot.parquet', index=False)
|
1022 |
+
|
1023 |
+
|
1024 |
+
from joblib import Parallel, delayed, effective_n_jobs
|
1025 |
+
from sklearn.utils import gen_even_slices
|
1026 |
+
from sklearn.utils.validation import _num_samples
|
1027 |
+
|
1028 |
+
|
1029 |
+
def parallel_apply(df, func, n_jobs=-1, **kwargs):
    """Pandas apply in parallel using joblib.

    The input is partitioned into even slices (sklearn.utils helpers) and each
    slice is applied in a separate joblib worker, then re-concatenated.

    Args:
        df: Pandas DataFrame, Series, or any other object supporting slicing and apply.
        func: Callable to apply.
        n_jobs: Desired number of workers; -1 means all available cores.
        **kwargs: Extra keyword arguments forwarded to the apply call.

    Returns:
        Same as a normal Pandas DataFrame.apply().
    """
    n_workers = effective_n_jobs(n_jobs)
    if n_workers == 1:
        # nothing to parallelize; plain apply
        return df.apply(func, **kwargs)
    apply_method = type(df).apply
    slices = gen_even_slices(_num_samples(df), n_workers)
    pieces = Parallel(n_jobs=n_jobs)(
        delayed(apply_method)(df[s], func, **kwargs) for s in slices)
    return pd.concat(pieces)
|
1051 |
+
|
1052 |
+
|
1053 |
+
def add_better_profanity_flag(df):
    """Attach a boolean 'better_profanity' column computed over df['text'].

    Uses the better_profanity package, parallelized across cores.
    Mutates and returns *df*.
    """
    from better_profanity import profanity

    def has_profanity(text):
        return profanity.contains_profanity(text)

    df['better_profanity'] = parallel_apply(df['text'], has_profanity, n_jobs=-1)
    return df
|
1061 |
+
|
1062 |
+
|
1063 |
+
def add_textstat_grade(df):
    """Attach a 'flesch_grade' column with the Flesch-Kincaid grade of df['text'].

    The (CPU-bound) textstat computation is spread across all cores via
    parallel_apply.  Mutates and returns *df*.
    """
    import textstat

    def myfunc(x):
        return textstat.flesch_kincaid_grade(x)  # simple grade

    # NOTE: a dask-based variant was removed here; it was dead code behind
    # `if False:` and at ~40s per 1000 rows was too slow for ~1.8M rows anyway.
    df['flesch_grade'] = parallel_apply(df['text'], myfunc, n_jobs=-1)
    return df
|
1079 |
+
|
1080 |
+
|
1081 |
+
def add_deberta_grade(df):
    """Attach a 'grade_deberta' reward-model score column to *df*.

    Splits each text row into question/answer around the <human>:/<bot>:
    markers, scores the pairs with the OpenAssistant DeBERTa reward model,
    and checkpoints progress to a host-specific pickle so an interrupted run
    can resume.  Mutates and returns *df* (also adds 'question'/'answer' columns).
    """
    from transformers import AutoModelForSequenceClassification, AutoTokenizer
    import torch
    reward_name = "OpenAssistant/reward-model-deberta-v3-large-v2"
    rank_model, tokenizer = AutoModelForSequenceClassification.from_pretrained(
        reward_name), AutoTokenizer.from_pretrained(reward_name)
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    rank_model.to(device)

    def get_question(x):
        # text before the first <bot>: marker, with the human marker removed
        return x.replace('<human>: ', '').split('<bot>:')[0]

    def get_answer(x):
        # first bot turn; fall back to the marker-without-space variant.
        # NOTE(review): bare except also hides unrelated errors (e.g. no
        # <bot> marker at all would raise IndexError in the fallback too).
        try:
            answer = x.split('<bot>: ')[1].split('<human>:')[0].replace('<bot>: ', '')
        except:
            answer = x.split('<bot>:')[1].split('<human>:')[0].replace('<bot>:', '')
        return answer

    df['question'] = parallel_apply(df['text'], get_question, n_jobs=-1)
    df['answer'] = parallel_apply(df['text'], get_answer, n_jobs=-1)

    from datasets import Dataset
    from transformers import pipeline
    from transformers.pipelines.pt_utils import KeyPairDataset
    import tqdm

    pipe = pipeline(
        "text-classification",
        model=reward_name,
        device="cuda:0" if torch.cuda.is_available() else "cpu"
    )
    start = 0
    batch_size = 64 * 16  # rows scored between two checkpoint writes
    micro_batch = orig_micro_batch = 16  # pipeline batch size; halved on OOM
    end = 0
    import socket
    # host-specific checkpoint so multiple machines don't clobber each other
    checkpoint = "grades.%s.pkl" % socket.gethostname()
    grades = []
    import pickle
    if os.path.exists(checkpoint):
        # resume from (rows-done, scores-so-far) saved by a previous run
        with open(checkpoint, "rb") as f:
            start, grades = pickle.loads(f.read())
    last_oom = 0
    while end < df.shape[0]:
        # manual batching to handle OOM more gracefully
        end = min(start + batch_size, df.shape[0])
        if start == end:
            break
        dataset = Dataset.from_pandas(df.iloc[start:end, :])
        try:
            grades.extend([
                x['score'] for x in tqdm.tqdm(
                    pipe(KeyPairDataset(dataset, "question", "answer"), batch_size=micro_batch)
                )
            ])
        except torch.cuda.OutOfMemoryError:
            # retry the same window with a smaller micro batch
            last_oom = start
            micro_batch = max(1, micro_batch // 2)
            print("OOM - retrying with micro_batch=%d" % micro_batch)
            continue
        if last_oom == start:
            # the reduced micro batch got us past the OOM window; restore it
            micro_batch = orig_micro_batch
            print("Returning to micro_batch=%d" % micro_batch)
        assert len(grades) == end
        start = end
        with open(checkpoint, "wb") as f:
            f.write(pickle.dumps((end, grades)))
        print("%d/%d" % (end, df.shape[0]))
    df['grade_deberta'] = grades
    if os.path.exists(checkpoint):
        os.remove(checkpoint)  # run completed; checkpoint no longer needed
    return df
|
1154 |
+
|
1155 |
+
|
1156 |
+
def test_chop_by_lengths():
    """Probabilistically drop rows with short human/bot turns and cap very long
    ones, then write 'h2oGPT.cleaned.chopped.human_bot.shorter.parquet'."""
    in_file = "h2oGPT.cleaned.human_bot.shorter.parquet"
    df = pd.read_parquet(in_file).reset_index(drop=True)
    df = count_human_bot_lengths(df)
    df['rand'] = np.random.rand(df.shape[0])
    df['rand2'] = np.random.rand(df.shape[0])
    n_before = df.shape[0]
    # throw away short human/bot responses with higher likelihood; the random
    # columns keep a fraction of the mid-length rows
    for side, rand_col in (('human', 'rand'), ('bot', 'rand2')):
        mean_col = 'len_%s_mean' % side
        df = df[df[mean_col] > 20]  # never keep very short ones
        df = df[(df[mean_col] > 30) | (df[rand_col] < 0.2)]
        df = df[(df[mean_col] > 50) | (df[rand_col] < 0.5)]
        # drop super long (basically single-speaker) ones
        df = df[df['len_%s_max' % side] < 10000]
    assert df['text'].apply(lambda x: len(x)).max() < 20000
    df = df.drop(['rand', 'rand2'], axis=1)
    n_after = df.shape[0]
    print("Chopped off %d out of %d rows due to length" % (n_before - n_after, n_before))
    print(df.describe())
    df.to_parquet('h2oGPT.cleaned.chopped.human_bot.shorter.parquet', index=False)
|
1178 |
+
|
1179 |
+
|
1180 |
+
def count_human_bot_lengths(df, human=None, bot=None):
    """Add per-row min/max/mean character lengths of human and bot turns.

    :param df: DataFrame with a 'text' column in <human>:/<bot>: format
    :param human: human turn marker (default '<human>:')
    :param bot: bot turn marker (default '<bot>:')
    :return: *df* with len_{human,bot}_{min,max,mean} columns added (mutated in place)

    Also seeds numpy's RNG (1234) and prints a describe() of the result.
    """
    import re
    len_human_min = []
    len_human_max = []
    len_human_mean = []
    len_bot_min = []
    len_bot_max = []
    len_bot_mean = []
    human = human or '<human>:'
    bot = bot or '<bot>:'
    for is_human in [True, False]:
        what = human if is_human else bot
        other = human if not is_human else bot
        for i in range(df.shape[0]):
            text = df.loc[i, 'text']
            assert isinstance(text, str)
            starts = [m.start() for m in re.finditer(what, text)]
            if len(starts) == 1:
                starts = [starts[0], len(text)]  # always go into for loop below
            assert len(text)
            list_what = []
            for ii in range(len(starts) - 1):
                interaction = text[starts[ii]: starts[ii + 1]]
                if other in interaction:
                    # cut off at the other speaker's turn
                    interaction = interaction[:interaction.find(other)]
                # BUG FIX: str.strip() returns a new string; the original code
                # discarded the result, so trailing whitespace inflated lengths.
                interaction = interaction.strip()
                list_what.append(interaction)
            if not list_what:
                list_what = ['']  # handle corrupted data, very rare, leads to sizes 0
            lengths = [len(x) for x in list_what]
            if is_human:
                len_human_min.append(min(lengths))
                len_human_max.append(max(lengths))
                len_human_mean.append(np.mean(lengths))
            else:
                len_bot_min.append(min(lengths))
                len_bot_max.append(max(lengths))
                len_bot_mean.append(np.mean(lengths))
    df['len_human_min'] = len_human_min
    df['len_human_max'] = len_human_max
    df['len_human_mean'] = len_human_mean
    df['len_bot_min'] = len_bot_min
    df['len_bot_max'] = len_bot_max
    df['len_bot_mean'] = len_bot_mean
    np.random.seed(1234)
    pd.set_option('display.max_columns', None)
    print("Before chopping")
    print(df.describe())
    return df
|
1228 |
+
|
1229 |
+
|
1230 |
+
def test_grade():
    """Run the three grading stages over the chopped dataset, checkpointed by file.

    Each stage is skipped if its output parquet already exists, so the pipeline
    can be resumed: Flesch-Kincaid grade filter -> better_profanity filter ->
    DeBERTa reward-model grade filter -> final copy to
    'h2oGPT.cleaned.graded.human_bot.shorter.parquet'.
    """
    df = None  # loaded lazily by whichever stage runs first

    file = "h2oGPT.cleaned.chopped.human_bot.shorter.parquet"
    output_file = "h2oGPT.cleaned.graded1.human_bot.shorter.parquet"
    if not os.path.exists(output_file):
        if df is None:
            df = pd.read_parquet(file).reset_index(drop=True)
        df = add_textstat_grade(df)
        # keep only mid-range reading-grade texts
        min_grade = 10
        max_grade = 25
        df = df[df['flesch_grade'] >= min_grade]
        df = df[df['flesch_grade'] <= max_grade]
        print("After Flesch grade")
        print(df.describe())
        df.to_parquet(output_file, index=False)

    file = output_file
    output_file = "h2oGPT.cleaned.graded2.human_bot.shorter.parquet"
    if not os.path.exists(output_file):
        # slower than alt-profanity, do last, but do before deberta grading, since that's slower
        if df is None:
            df = pd.read_parquet(file).reset_index(drop=True)
        df = add_better_profanity_flag(df)
        before_rows = df.shape[0]
        df = df[df['better_profanity'] == 0]
        df = df.drop(['better_profanity'], axis=1)
        after_rows = df.shape[0]
        print("Dropped %d rows out of %d due to better_profanity" % (before_rows - after_rows, before_rows))
        print(df.describe())
        df.to_parquet(output_file, index=False)

    file = output_file
    output_file = 'h2oGPT.cleaned.graded3.human_bot.shorter.parquet'
    if not os.path.exists(output_file):
        if df is None:
            df = pd.read_parquet(file).reset_index(drop=True)
        df = add_deberta_grade(df)
        # drop rows the reward model scores below 0.3
        min_grade = 0.3
        max_grade = np.inf
        before_rows = df.shape[0]
        df = df[df['grade_deberta'] >= min_grade]
        df = df[df['grade_deberta'] <= max_grade]
        after_rows = df.shape[0]
        print("Dropped %d rows out of %d due to deberta grade" % (before_rows - after_rows, before_rows))
        print("After DeBERTa grade")
        print(df.describe())
        df.to_parquet(output_file, index=False)

    file = output_file
    output_file = 'h2oGPT.cleaned.graded.human_bot.shorter.parquet'
    if df is None:
        df = pd.read_parquet(file).reset_index(drop=True)
    df.to_parquet(output_file, index=False)
|
1284 |
+
|
1285 |
+
|
1286 |
+
@pytest.mark.parametrize(
    "fixup_personality, only_personality, deberta_grading",
    [
        # [False, False, False],
        # [True, True, False],
        [True, False, False],
        # [True, False, True],
    ]
)
@pytest.mark.parametrize("prompt_type", ["llama2"])
def test_add_open_assistant(fixup_personality, only_personality, deberta_grading, prompt_type, save_json=True):
    """
    Flatten tree structure into one row per path from root to leaf
    Also turn into human_bot prompting format:
    <human>: question\n<bot>: answer <human>: question2\n<bot>: answer2 Etc.
    Also saves a .json locally as side-effect
    returns list of dicts, containing intput, prompt_type and source
    """
    from datasets import load_dataset
    data_file = "OpenAssistant/oasst1"
    ds = load_dataset(data_file)
    df = pd.concat([ds['train'].to_pandas(), ds['validation'].to_pandas()], axis=0)
    rows = {}
    message_ids = df['message_id'].values.tolist()
    message_tree_ids = df['message_tree_id'].values.tolist()
    parent_ids = df['parent_id'].values.tolist()
    texts = df['text'].values.tolist()
    roles = df['role'].values.tolist()
    deleteds = df['deleted'].values.tolist()
    for i in range(df.shape[0]):
        # collect all trees
        message_id = message_ids[i]
        message_tree_id = message_tree_ids[i]
        parent_id = parent_ids[i]
        text = texts[i]
        deleted = deleteds[i]
        if deleted:
            continue
        if fixup_personality:
            # rebrand the assistant identity; order matters: the specific
            # (mis)spellings must be replaced before the generic "Assistant"
            text = text.replace("Open Assistant", "h2oGPT")
            text = text.replace("Open-Assistant", "h2oGPT")
            text = text.replace("open-assistant", "h2oGPT")
            text = text.replace("OpenAssistant", "h2oGPT")
            text = text.replace("open assistant", "h2oGPT")
            text = text.replace("Open Assistand", "h2oGPT")
            text = text.replace("Open Assitant", "h2oGPT")
            text = text.replace("Open Assistent", "h2oGPT")
            text = text.replace("Open Assisstant", "h2oGPT")
            text = text.replace("Open Assitent", "h2oGPT")
            text = text.replace("Open Assitiant", "h2oGPT")
            text = text.replace("Open Assistiant", "h2oGPT")
            text = text.replace("Open Assitan ", "h2oGPT ")
            text = text.replace("Open Assistan ", "h2oGPT ")
            text = text.replace("Open Asistant", "h2oGPT")
            text = text.replace("Open Assiant", "h2oGPT")
            text = text.replace("Assistant", "h2oGPT")
            text = text.replace("LAION AI", "H2O.ai")
            text = text.replace("LAION-AI", "H2O.ai")
            text = text.replace("LAION,", "H2O.ai,")
            text = text.replace("LAION.ai", "H2O.ai")
            text = text.replace("LAION.", "H2O.ai.")
            text = text.replace("LAION", "H2O.ai")

        role = roles[i]
        if prompt_type == "llama2":
            new_data = ('[INST] ' if role == 'prompter' else ' [/INST] ') + text
            if parent_id and role == 'prompter':
                new_data = " " + new_data
        elif prompt_type == "human_bot":
            new_data = ('<human>: ' if role == 'prompter' else '<bot>: ') + text
        else:
            raise NotImplementedError("prompt_type not supported")
        entry = dict(message_id=message_id, parent_id=parent_id, text=new_data)
        if message_tree_id not in rows:
            rows[message_tree_id] = [entry]
        else:
            rows[message_tree_id].append(entry)

    all_rows = []

    for node_id in rows:
        # order responses in tree, based on message/parent relationship
        conversations = []

        list_msgs = rows[node_id]
        # find start
        while len(list_msgs):
            for i, leaf in enumerate(list_msgs):
                found = False
                parent_id = leaf['parent_id']
                if parent_id is None:
                    # conversation starter
                    conversations.append(leaf)
                    found = True
                else:
                    # a conversation's 'message_id' is the concatenation of all
                    # message ids along its path; the parent must be the LAST one
                    for conv in conversations:
                        # find all conversations to add my message to
                        if parent_id in conv['message_id'] and parent_id != conv['message_id'][-len(parent_id):]:
                            # my message doesn't follow conversation
                            continue
                        if parent_id == conv['message_id'][-len(parent_id):]:
                            # my message follows conversation, but fork first, so another follow-on message can do same
                            conversations.append(conv.copy())
                            if prompt_type == "llama2":
                                conv['text'] += f"""{leaf['text']}"""
                            elif prompt_type == "human_bot":
                                conv['text'] += f"""
{leaf['text']}
"""
                            else:
                                raise NotImplementedError
                            conv['message_id'] += leaf['message_id']
                            found = True
                            break
                if found:
                    # my content was used, so nuke from list
                    break
            if found:
                # my content was used, so nuke from list
                del list_msgs[i]
            else:
                break

        # now reduce down to final conversations, find the longest chains of message ids
        # NOTE(review): correctness relies on forked (shorter) copies being
        # appended AFTER their extended originals; nulling the shorter one
        # first keeps conv['message_id'] a str when the containment checks
        # below run -- confirm before reordering anything here.
        for i, conv in enumerate(conversations):
            for j, conv2 in enumerate(conversations):
                if i == j:
                    continue
                if conv['message_id'] and conv2['message_id']:
                    assert conv['message_id'] != conv2['message_id']
                    # delete the shorter conversation, if one contains the other
                    if conv['message_id'] in conv2['message_id']:
                        conv['message_id'] = None
                    if conv2['message_id'] in conv['message_id']:
                        conv2['message_id'] = None
        conversations = [c for c in conversations if c['message_id']]
        if only_personality:
            # keep only conversations mentioning the injected h2oGPT identity
            if prompt_type == "human_bot":
                all_rows.extend(
                    [dict(input=c['text'] + "\n<human>:", output="", prompt_type='plain', source=data_file) for c in conversations if
                     'h2oGPT' in c['text']])
            elif prompt_type == "llama2":
                all_rows.extend(
                    [dict(input=c['text'] +
                          ("" if c['text'].rfind("[/INST]") > c['text'].rfind("[INST]") else " [/INST]"),
                          output="", prompt_type='plain', source=data_file) for c in conversations if
                     'h2oGPT' in c['text']])
            else:
                raise NotImplementedError
        else:
            # keep everything except rows that duplicate the canned personality Q
            if prompt_type == "human_bot":
                all_rows.extend(
                    [dict(input=c['text'] + "\n<human>:", output="", prompt_type='plain', source=data_file) for c in conversations
                     if
                     "What is H2O.ai" not in c['text']])
            elif prompt_type == "llama2":
                all_rows.extend(
                    [dict(input=c['text'] +
                          (" " if c['text'].rfind("[/INST]") > c['text'].rfind("[INST]") else " [/INST]"),
                          output="", prompt_type='plain', source=data_file) for c in conversations if
                     "What is H2O.ai" not in c['text']])
            else:
                raise NotImplementedError

    unhelpful = get_unhelpful_list()
    all_rows = [x for x in all_rows if not any(u in x['input'] for u in unhelpful)]
    # oversample the personality rows 10x
    personality = create_personality_data(prompt_type=prompt_type)
    all_rows.extend(personality * 10)
    np.random.seed(123)
    np.random.shuffle(all_rows)
    print(len(all_rows))
    if deberta_grading:
        df = pd.DataFrame(all_rows)
        df = df.rename(columns={'input': 'text'})
        df = add_deberta_grade(df)
        df = df.rename(columns={'text': 'input'})
        drop = True
        if drop:
            min_grade = 0.3
            max_grade = np.inf
            before_rows = df.shape[0]
            df = df[df['grade_deberta'] >= min_grade]
            df = df[df['grade_deberta'] <= max_grade]
            after_rows = df.shape[0]
            print("Dropped %d rows out of %d due to deberta grade" % (before_rows - after_rows, before_rows))
            print("After DeBERTa grade")
            print(df.describe())
        all_rows = []
        for i in range(df.shape[0]):
            all_rows.append(
                dict(
                    input=df['input'].iloc[i],
                    output=df['output'].iloc[i],
                    source=df['source'].iloc[i],
                    prompt_type=df['prompt_type'].iloc[i],
                    grade_deberta=df['grade_deberta'].iloc[i],
                )
            )
    if save_json:
        # output filename encodes the parametrization
        data_file = data_file + \
                    ("_h2ogpt" if fixup_personality else "") + \
                    ("_only" if only_personality else "") + \
                    ("_graded" if deberta_grading else "") + \
                    ("_llama2_chat" if prompt_type == "llama2" else "")
        for i in range(len(all_rows)):
            all_rows[i]['id'] = i
        with open(data_file.lower().replace("/", "_") + ".json", "w") as f:
            f.write(json.dumps(all_rows, indent=2))
    return all_rows
|
1491 |
+
|
1492 |
+
|
1493 |
+
def test_finalize_to_json():
    """Merge the graded OIG data with the open-assistant json, run a final
    profanity pass with a custom censor list, drop unhelpful responses, and
    write 'h2ogpt-oig-oasst1-instruct-cleaned-v3.json'."""
    df = pd.read_parquet('h2oGPT.cleaned.graded.human_bot.shorter.parquet')
    df = df.rename(columns={'text': 'input'})

    print("Number of high-quality human_bot interactions: %s" % df.shape[0], flush=True)

    print("Adding open assistant data")
    with open("openassistant_oasst1_h2ogpt_graded.json") as f:
        open_assistant = json.loads(f.read())
    df = pd.concat([df, pd.DataFrame(open_assistant)], axis=0)

    def final_clean(df):
        # final pass with a project-specific censor-word list
        from better_profanity import profanity
        profanity.load_censor_words_from_file("data/censor_words.txt")
        df['profanity'] = parallel_apply(
            df['input'],
            lambda x: profanity.contains_profanity(x),
            n_jobs=-1,
        )
        return df[(df['profanity'] == 0)].reset_index(drop=True)

    print("Before cleaning: Number of final high-quality human_bot interactions: %s" % df.shape[0], flush=True)
    df = final_clean(df)
    print("After cleaning: Number of final high-quality human_bot interactions: %s" % df.shape[0], flush=True)
    print(df.describe())
    print(df.shape)
    row_list = []
    for i in range(df.shape[0]):
        row_list.append(
            dict(
                input=df.loc[i, 'input'],
                source=df.loc[i, 'source'],
                prompt_type='plain',
            )
        )
    np.random.seed(1234)
    np.random.shuffle(row_list)
    unhelpful = get_unhelpful_list()
    row_list = [x for x in row_list if not any(u in x['input'] for u in unhelpful)]
    for i in range(len(row_list)):
        row_list[i]['id'] = i
        # normalize bot marker onto its own line
        row_list[i]['input'] = row_list[i]['input'].replace(" <bot>:", "\n<bot>:")
    with open('h2ogpt-oig-oasst1-instruct-cleaned-v3.json', "w") as f:
        f.write(json.dumps(row_list, indent=2))
|
1537 |
+
|
1538 |
+
|
1539 |
+
def create_personality_data(prompt_type="llama2"):
    """Build the small 'personality' dataset that teaches the model its identity.

    Generates every (question, answer, suffix) combination plus a few fixed
    Q/A pairs about H2O.ai, writes them to h2ogpt-personality.json in cwd,
    and returns the rows.

    :param prompt_type: prompt_type tag stored on every row (default "llama2")
    :return: list of dicts with keys input, output, prompt_type, source
    """
    questions = [
        "What's your name?",
        "What is your name?",
        "What are you?",
        "Who are you?",
        "Do you have a name?",
        "Who trained you?",
        "Who created you?",
        "Who made you?",
    ]
    answers = [
        "I'm h2oGPT, a large language model by H2O.ai.",
        "I'm h2oGPT, a large language model by H2O.ai, the visionary leader in democratizing AI.",
        "My name is h2oGPT. I'm a large language model by H2O.ai, the visionary leader in democratizing AI.",
        "My name is h2oGPT. I'm a large language model trained by H2O.ai.",
        "Hi! I'm h2oGPT, a large language model by H2O.ai.",
        "Hi! I'm h2oGPT, a large language model by H2O.ai, the visionary leader in democratizing AI.",
    ]
    # optional trailing pleasantry appended to each answer ("" = none);
    # renamed from `help`, which shadowed the builtin
    help_suffixes = [
        "",
        " How can I help you?",
        " How may I assist you?",
        " Nice to meet you.",
    ]
    import itertools
    rows = []
    for question, answer, suffix in itertools.product(questions, answers, help_suffixes):
        rows.append(
            dict(input=f"{question}", output=f"{answer}{suffix}", prompt_type=prompt_type, source="H2O.ai")
        )
    # fixed company Q/A pairs; the two blurbs are reused for several phrasings
    about = "H2O.ai is a technology company that aims to democratize AI and make it accessible to a broader audience by simplifying the process of creating and deploying machine learning models."
    leader = "H2O.ai is the visionary leader in democratizing AI."
    for q, a in [
        ("What is H2O.ai?", about),
        ("What is h2o.ai?", about),
        ("What is H2O?", about),
        ("Who is h2o.ai?", about),
        ("who is h2o.ai?", about),
        ("who is h2o?", about),
        ("what is H2O.ai?", leader),
        ("who is H2O.ai?", leader),
        ("who is H2O?", leader),
        ("Who is h20?", leader),
    ]:
        rows.append(dict(input=q, output=a, prompt_type=prompt_type, source='H2O.ai'))
    print(len(rows))
    with open("h2ogpt-personality.json", "w") as f:
        f.write(json.dumps(rows, indent=2))
    return rows
|
1587 |
+
|
1588 |
+
|
1589 |
+
def test_check_stats_data():
    """Plot character-count and token-count histograms for the final dataset.

    Reads h2ogpt-oig-oasst1-instruct-cleaned-v3.json; saves chars_hist.png
    and token_hist_<cutoff>.png in cwd.
    """
    filename = 'h2ogpt-oig-oasst1-instruct-cleaned-v3.json'
    df = pd.read_json(filename)

    # get word stats
    df['char_count'] = df['input'].apply(lambda x: len(x))
    import matplotlib.pyplot as plt
    plt.figure(figsize=(10, 10))
    plt.hist(df['char_count'], bins=100)
    chars_avg = np.mean(df['char_count'])
    chars_median = np.median(df['char_count'])
    plt.title("char_count avg: %s median: %s" % (chars_avg, chars_median))
    plt.savefig('chars_hist.png')
    plt.close()

    # get tokenize stats for random sample of 1000 rows
    from finetune import generate_and_tokenize_prompt
    from loaders import get_loaders, get_tokenizer
    from functools import partial

    llama_type = False
    tokenizer_base_model = base_model = 'h2oai/h2ogpt-oasst1-512-20b'
    model_loader, tokenizer_loader, conditional_type = (
        get_loaders(model_name=base_model, reward_type=False, llama_type=llama_type))
    local_files_only = False
    resume_download = True
    use_auth_token = False
    tokenizer = get_tokenizer(tokenizer_loader, tokenizer_base_model, local_files_only, resume_download, use_auth_token)
    prompt_type = 'plain'  # trained with data already in human bot form
    train_on_inputs = True
    add_eos_token = False
    cutoff_len = 512  # can choose 2048
    generate_and_tokenize_prompt_fun = partial(generate_and_tokenize_prompt, prompt_type=prompt_type,
                                               train_on_inputs=train_on_inputs, add_eos_token=add_eos_token,
                                               cutoff_len=cutoff_len, tokenizer=tokenizer)
    from datasets import load_dataset
    data = load_dataset("json", data_files={"train": filename})
    # NOTE(review): 90% held out so only ~10% of rows get tokenized below --
    # presumably intended as a sampling shortcut; confirm
    val_set_size = 0.90
    train_val = data["train"].train_test_split(
        test_size=val_set_size, shuffle=True, seed=42
    )
    train_data = train_val["train"]
    train_data = train_data.shuffle().map(generate_and_tokenize_prompt_fun, num_proc=os.cpu_count())

    # per-example token counts after tokenization (post-cutoff)
    df_tokens = pd.DataFrame([len(x) for x in train_data['input_ids']], columns=['token_count'])

    plt.figure(figsize=(10, 10))
    plt.hist(df_tokens['token_count'], bins=100)
    token_avg = np.mean(df_tokens['token_count'])
    token_median = np.median(df_tokens['token_count'])
    plt.title("token_count with cutoff=%s avg: %s median: %s" % (cutoff_len, token_avg, token_median))
    plt.savefig('token_hist_%s.png' % cutoff_len)
    plt.close()
|
1642 |
+
|
1643 |
+
|
1644 |
+
def get_unhelpful_list():
    """Return the list of phrases that mark a response as unhelpful/canned.

    Used as substrings to filter out bot turns (and whole conversations)
    containing refusals, confusion boilerplate, or filler.

    Fix: several entries were previously missing trailing commas, so Python's
    implicit string-literal concatenation silently fused adjacent phrases into
    one never-matching string (e.g. "could you please rephrase" +
    "could you rephrase" + the next item). Commas restored so every phrase is
    a distinct list element.

    :return: list of substrings to match against responses
    """
    # base versions
    unhelpful = ["I'm sorry, I didn't quite understand your question, could you please rephrase it?",
                 "I'm sorry, but I don't understand your question. Could you please rephrase it?",
                 "I'm sorry, I don't quite understand your question",
                 "I'm sorry, I don't know",
                 "I'm sorry, but I don't know",
                 "I don't know anything",
                 "I do not know",
                 "I don't know",
                 "I don't know how",
                 "I do not know how",
                 "Can you please explain what you mean",
                 "please explain what you mean",
                 "please explain",
                 "I'm sorry, but I don't know how to tell a story. Can you please explain what you mean by",
                 "I'm sorry but I don't understand what you mean",
                 "I don't understand",
                 "I don't have the ability",
                 "I do not have the ability",
                 "I do not have",
                 "I am a language model,",
                 "I am a large language model,",
                 "I do not understand your question. Can you please try to make it clearer?",
                 "I'm sorry, but as an AI language model",
                 "I apologize, but I cannot rephrase text that I cannot understand. Your post is difficult to read and follow.",
                 "I apologize, but I am not h2oGPT. I am a language model developed by H2O.ai. How may I help you?",
                 "Sorry, but I am not an actual Linux shell, nor am I capable of emulating one. I am an open source chat assistant and would be glad t",
                 "I apologize, but I cannot perform the task you have requested.",
                 "I'm sorry, I cannot perform this task as I am an AI language model and do not have access",
                 "I'm sorry, I'm not sure what you're asking for here.",
                 "I'm not sure what you are asking",
                 "You need to provide more context",
                 ]
    # reduced versions, with redundant parts, just to give context for where they came from
    unhelpful += ["sorry, I didn't quite understand your question",
                  "I didn't quite understand your question",
                  "I didn't understand your question",
                  "I did not understand your question",
                  "I did not understand the question",
                  "could you please rephrase",
                  "could you rephrase",
                  "I do not understand your question.",
                  "I do not understand the question.",
                  "I do not understand that question.",
                  "Can you please try to make it clearer",
                  "Can you try to make it clearer",
                  "sorry, but as an AI language model",
                  "as an AI language model",
                  "I apologize, but I cannot",
                  "I cannot rephrase text",
                  "I cannot understand. Your post is difficult to read and follow.",
                  "Your post is difficult to read and follow.",
                  "I apologize, but I am",
                  "Sorry, but I am not ",
                  "nor am I capable",
                  "I am not capable of",
                  "I apologize, but I cannot perform the task you have requested",
                  "I cannot perform the task",
                  "I cannot complete the task",
                  "I'm sorry",
                  "I am sorry",
                  "do not have access",
                  "not sure what you're asking for",
                  "not sure what you are asking for",
                  "not sure what is being asked",
                  "I'm not sure what you are asking",
                  "not sure what you are asking",
                  "You need to provide more context",
                  "provide more context",
                  ]
    # boilerplate, emoticons, and filler
    unhelpful += ["As a large language model",
                  "cannot provide any information",
                  "As an artificial intelligence I do not have the capability",
                  "As an artificial intelligence I don't have the capability",
                  "As an artificial intelligence I can't",
                  "As an artificial intelligence I cannot",
                  "I am sorry but I do not understand",
                  "Can you please explain",
                  "(sorry couldn't resist)",
                  "(sorry could not resist)",
                  " :)",
                  " ;)",
                  " :-)",
                  " ;-)",
                  " lol ",
                  "Thanks so much!!!",
                  "Thank You :)!!!",
                  "Please try not to repeat",
                  "I am an AI language model",
                  "I'm a AI assistant that",
                  "I'm an AI assistant that",
                  "I am an AI assistant that",
                  "etc.",
                  "etc.etc.",
                  "etc. etc.",
                  "etc etc",
                  ]
    return unhelpful
|
1743 |
+
|
1744 |
+
|
1745 |
+
def test_check_unhelpful():
    """Count known unhelpful/canned phrases in a dataset and assert none remain.

    Optionally pre-filters by reward score, per-turn BLEU similarity, or
    sentence-embedding cosine similarity against the unhelpful-phrase list.
    NOTE(review): relies on module-level `human`/`bot` marker strings defined
    elsewhere in this file, and on a hard-coded local file path.
    """
    # file = '/home/jon/Downloads/openassistant_oasst1_h2ogpt_graded.json'
    file = '/home/jon/Downloads/openassistant_oasst1_h2ogpt_grades.json'
    # file = 'h2ogpt-oig-oasst1-instruct-cleaned-v2.json'

    unhelpful = get_unhelpful_list()
    # data = json.load(open(file, 'rt'))
    df = pd.read_json(file)

    # toggles for the three optional filtering strategies
    use_reward_score_threshold = False
    use_bleu_threshold = False
    use_sentence_sim = True

    from sacrebleu.metrics import BLEU
    bleu = BLEU()
    from nltk.translate.bleu_score import sentence_bleu

    def get_bleu(actual, expected_list):
        # sacrebleu alternative kept for reference:
        # return bleu.sentence_score(actual, expected_list).score
        return sentence_bleu(expected_list, actual)

    threshold = 0.0
    if use_reward_score_threshold:
        df = df[df['grade_deberta'] > threshold]

    # back to as if original json load
    data = df.to_dict(orient='records')
    bads = {}
    # count phrase occurrences over the stringified dataset in one pass
    string_all = str(data)
    for sub in unhelpful:
        bads[sub] = string_all.count(sub)
    bads = {k: v for k, v in bads.items() if v > 0}
    import pprint
    pp = pprint.PrettyPrinter(indent=4)
    pp.pprint(bads)

    total_bads = sum(list(bads.values()))
    print('total_bads: %s' % total_bads, flush=True)

    # check just bot
    import re
    # split each conversation on the human/bot markers into alternating turns;
    # even indices are human turns, odd indices are bot turns
    convs = [[x.strip() for x in re.split(r'%s|%s' % (human, bot), y['input']) if x.strip()] for y in data]
    humans = [[x for i, x in enumerate(y) if i % 2 == 0] for y in convs]
    bots = [[x for i, x in enumerate(y) if i % 2 == 1] for y in convs]

    # FIXME: apply back to json etc., just see for now
    bleu_threshold = 0.9
    if use_bleu_threshold:
        bots = [[x for x in y if get_bleu(x, unhelpful) < bleu_threshold] for y in tqdm(bots)]

    cosine_sim_threshold = 0.8
    if use_sentence_sim:
        # pip install sentence_transformers-2.2.2
        from sentence_transformers import SentenceTransformer
        # sent_model = 'bert-base-nli-mean-tokens'
        # sent_model = 'nli-distilroberta-base-v2'
        sent_model = 'all-MiniLM-L6-v2'
        model = SentenceTransformer(sent_model)
        sentence_embeddings = model.encode(unhelpful)
        from sklearn.metrics.pairwise import cosine_similarity
        # NOTE(review): this drops whole conversations when any turn is too
        # similar, unlike the per-turn BLEU filter above -- confirm intended
        bots = [x for x in tqdm(bots) if
                np.max(cosine_similarity(model.encode(x), sentence_embeddings)) < cosine_sim_threshold]

    bads_bots = {}
    string_all = str(bots)
    for sub in unhelpful:
        bads_bots[sub] = string_all.count(sub)
    bads_bots = {k: v for k, v in bads_bots.items() if v > 0}
    import pprint
    pp = pprint.PrettyPrinter(indent=4)
    pp.pprint(bads_bots)

    total_bads_bots = sum(list(bads_bots.values()))
    print('threshold: %g use_bleu_threshold: %g total_bads_bots: %s total_bots: %s total_humans: %s' % (
        threshold, use_bleu_threshold, total_bads_bots, len(bots), len(humans)), flush=True)

    # assert len(bads) == 0, bads
    assert len(bads_bots) == 0, bads_bots
|
1823 |
+
|
1824 |
+
|
1825 |
+
def test_fortune2000_personalized():
    """Build plain-text wikitext chunks mixed with repeated personality rows.

    Requires the unzipped `wikitext/` folder of per-company .txt files.
    Writes h2ogpt-fortune2000-personalized.json in cwd.
    """
    row_list = []
    import glob
    if not os.path.isdir("wikitext"):
        raise RuntimeError("download https://github.com/h2oai/h2ogpt/files/11423008/wikitext.zip and unzip")
    for file in glob.glob("wikitext/*.txt"):
        with open(file, "r") as f:
            blob = f.read()
            N = 512 * 4  # target chunk size in characters
            row_list.extend([{'input': s, 'prompt_type': 'plain', 'source': "%s" % os.path.basename(file)}
                             for s in get_sentences(blob, N) if s])
    personality = create_personality_data()
    import copy
    # oversample the personality rows 10x (deep-copied so ids stay distinct)
    for i in range(10):
        row_list.extend(copy.deepcopy(personality))
    np.random.seed(123)  # deterministic shuffle
    np.random.shuffle(row_list)
    for i in range(len(row_list)):
        row_list[i]['id'] = i
    for i in range(len(row_list)):
        assert row_list[i]['id'] == i
    with open("h2ogpt-fortune2000-personalized.json", "w") as ff:
        ff.write(json.dumps(row_list, indent=2))
|
src/db_utils.py
ADDED
@@ -0,0 +1,54 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import uuid
|
2 |
+
|
3 |
+
from enums import LangChainMode
|
4 |
+
|
5 |
+
|
6 |
+
def set_userid(db1s, requests_state1, get_userid_auth):
    """Fill in userid (slot 1) and username (slot 2) on the MyData db state.

    Only assigns values that are currently unset, so an existing id/username
    is never overwritten.

    :param db1s: dict mapping LangChainMode value -> [db, userid, username]
    :param requests_state1: per-request state dict; may carry 'username'
    :param get_userid_auth: callable deriving a user id from requests_state1
    """
    db1 = db1s[LangChainMode.MY_DATA.value]
    assert db1 is not None and len(db1) == length_db1()
    if not db1[1]:
        db1[1] = get_userid_auth(requests_state1)
    if not db1[2]:
        username1 = None
        if 'username' in requests_state1:
            username1 = requests_state1['username']
        db1[2] = username1
|
16 |
+
|
17 |
+
|
18 |
+
def set_userid_direct(db1s, userid, username):
    """Directly store *userid* (slot 1) and *username* (slot 2) on MyData state."""
    my_data = db1s[LangChainMode.MY_DATA.value]
    my_data[1], my_data[2] = userid, username
|
22 |
+
|
23 |
+
|
24 |
+
def get_userid_direct(db1s):
    """Return the stored user id for MyData, or '' when no state exists."""
    if db1s is None:
        return ''
    return db1s[LangChainMode.MY_DATA.value][1]
|
26 |
+
|
27 |
+
|
28 |
+
def get_username_direct(db1s):
    """Return the stored username for MyData, or '' when no state exists."""
    if db1s is None:
        return ''
    return db1s[LangChainMode.MY_DATA.value][2]
|
30 |
+
|
31 |
+
|
32 |
+
def get_dbid(db1):
    """Return the dbid stored in slot 1 of a per-collection state list."""
    return db1[1]
|
34 |
+
|
35 |
+
|
36 |
+
def set_dbid(db1):
    """Assign a fresh uuid as dbid (slot 1) if one is not already set."""
    # can only call this after function called so for specific user, not in gr.State() that occurs during app init
    assert db1 is not None and len(db1) == length_db1()
    if db1[1] is None:
        # uuid in db is used as user ID
        db1[1] = str(uuid.uuid4())
|
42 |
+
|
43 |
+
|
44 |
+
def length_db1():
    """Return the fixed length of a per-collection db state list."""
    # For MyData:
    # 0: db
    # 1: userid and dbid
    # 2: username

    # For others:
    # 0: db
    # 1: dbid
    # 2: None
    return 3
|
src/enums.py
ADDED
@@ -0,0 +1,225 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from enum import Enum
|
2 |
+
|
3 |
+
|
4 |
+
class PromptType(Enum):
    """Supported prompt template formats.

    NOTE(review): integer values look like stable identifiers (sequentially
    assigned, with -1 reserved for custom) -- avoid renumbering existing
    members; append new ones at the end.
    """
    custom = -1
    plain = 0
    instruct = 1
    quality = 2
    human_bot = 3
    dai_faq = 4
    summarize = 5
    simple_instruct = 6
    instruct_vicuna = 7
    instruct_with_end = 8
    human_bot_orig = 9
    prompt_answer = 10
    open_assistant = 11
    wizard_lm = 12
    wizard_mega = 13
    instruct_vicuna2 = 14
    instruct_vicuna3 = 15
    wizard2 = 16
    wizard3 = 17
    instruct_simple = 18
    wizard_vicuna = 19
    openai = 20
    openai_chat = 21
    gptj = 22
    prompt_answer_openllama = 23
    vicuna11 = 24
    mptinstruct = 25
    mptchat = 26
    falcon = 27
    guanaco = 28
    llama2 = 29
    beluga = 30
    wizard3nospace = 31
    one_shot = 32
    falcon_chat = 33
|
40 |
+
|
41 |
+
|
42 |
+
class DocumentSubset(Enum):
    """Which subset of documents a request operates on."""
    Relevant = 0
    RelSources = 1
    TopKSources = 2


# subset choices that list sources rather than answering a query
non_query_commands = [
    DocumentSubset.RelSources.name,
    DocumentSubset.TopKSources.name
]
|
52 |
+
|
53 |
+
|
54 |
+
class DocumentChoice(Enum):
    """Special document-selection values; ALL means no per-document filter."""
    ALL = 'All'
|
56 |
+
|
57 |
+
|
58 |
+
class LangChainMode(Enum):
    """LangChain mode: which document collection (if any) backs a request."""

    DISABLED = "Disabled"  # no LangChain at all
    LLM = "LLM"  # LLM only, no document retrieval (see langchain_modes_non_db)
    WIKI = "wiki"
    WIKI_FULL = "wiki_full"
    USER_DATA = "UserData"
    MY_DATA = "MyData"  # per-user collection (carries userid/username state)
    GITHUB_H2OGPT = "github h2oGPT"
    H2O_DAI_DOCS = "DriverlessAI docs"
|
69 |
+
|
70 |
+
|
71 |
+
class LangChainTypes(Enum):
    """Whether a collection/database is shared across users or per-user."""
    SHARED = 'shared'
    PERSONAL = 'personal'
    EITHER = 'either'  # used when user did not pass which one, so need to try both


# modes should not be removed from visible list or added by name
langchain_modes_intrinsic = [LangChainMode.DISABLED.value,
                             LangChainMode.LLM.value,
                             LangChainMode.MY_DATA.value]

# modes that have no backing document database
langchain_modes_non_db = [LangChainMode.DISABLED.value,
                          LangChainMode.LLM.value]
|
84 |
+
|
85 |
+
|
86 |
+
class LangChainAction(Enum):
    """LangChain action: what to do with the retrieved documents."""

    QUERY = "Query"
    # WIP:
    # SUMMARIZE_MAP = "Summarize_map_reduce"
    SUMMARIZE_MAP = "Summarize"
    SUMMARIZE_ALL = "Summarize_all"
    SUMMARIZE_REFINE = "Summarize_refine"
|
95 |
+
|
96 |
+
|
97 |
+
class LangChainAgent(Enum):
    """LangChain agents"""

    SEARCH = "Search"
    COLLECTION = "Collection"
    PYTHON = "Python"
    CSV = "CSV"
    PANDAS = "Pandas"
    JSON = 'JSON'
|
106 |
+
|
107 |
+
|
108 |
+
# sentinel label used in UI dropdowns meaning "none selected / remove"
no_server_str = no_lora_str = no_model_str = '[None/Remove]'

# from site-packages/langchain/llms/openai.py
# but needed since ChatOpenAI doesn't have this information
model_token_mapping = {
    "gpt-4": 8192,
    "gpt-4-0314": 8192,
    "gpt-4-32k": 32768,
    "gpt-4-32k-0314": 32768,
    "gpt-3.5-turbo": 4096,
    "gpt-3.5-turbo-16k": 16 * 1024,
    "gpt-3.5-turbo-0301": 4096,
    "text-ada-001": 2049,
    "ada": 2049,
    "text-babbage-001": 2040,
    "babbage": 2049,
    "text-curie-001": 2049,
    "curie": 2049,
    "davinci": 2049,
    "text-davinci-003": 4097,
    "text-davinci-002": 4097,
    "code-davinci-002": 8001,
    "code-davinci-001": 8001,
    "code-cushman-002": 2048,
    "code-cushman-001": 2048,
}

# HTML rendering knobs for the sources block appended to answers
font_size = 2
head_acc = 40  # 40 for 6-way
source_prefix = "Sources [Score | Link]:"
source_postfix = "End Sources<p>"

# collapsible (<details>) variants wrapping the same source markers
super_source_prefix = f"""<details><summary><font size="{font_size}">Sources</font></summary><font size="{font_size}"><font size="{font_size}">Sources [Score | Link]:"""
super_source_postfix = f"""End Sources<p></font></font></details>"""
|
142 |
+
|
143 |
+
|
144 |
+
def t5_type(model_name):
    """Return True when *model_name* names a T5-family model (t5, flan, fastchat-t5)."""
    lowered = model_name.lower()
    if lowered == 't5':
        return True
    return any(marker in lowered for marker in ('t5-', 'flan-', 'fastchat-t5'))
|
149 |
+
|
150 |
+
|
151 |
+
def get_langchain_prompts(pre_prompt_query, prompt_query, pre_prompt_summary, prompt_summary,
                          model_name, inference_server, model_path_llama):
    """Fill any None prompt pieces with defaults suited to the model family.

    Falcon/LLaMA-2 (and the no-model case) get a strict "context only" query
    preamble; OpenAI inference servers get a stricter variant; other models
    get empty query prompts. Summary prompts are model-independent.
    Explicitly supplied (non-None) values are returned unchanged.

    :return: (pre_prompt_query, prompt_query, pre_prompt_summary, prompt_summary)
    """
    lowered_name = model_name.lower() if model_name else ''
    falcon_or_llama2 = bool(model_name) and (
            'falcon' in model_name or
            'llama-2' in lowered_name or
            bool(model_path_llama and 'llama-2' in model_path_llama.lower()))
    # use when no model, like no --base_model
    no_model = model_name in [None, '']

    if falcon_or_llama2 or no_model:
        default_pre_query = "Pay attention and remember the information below, which will help to answer the question or imperative after the context ends.\n"
        default_query = "According to only the information in the document sources provided within the context above, "
    elif inference_server and inference_server.startswith('openai'):
        default_pre_query = "Pay attention and remember the information below, which will help to answer the question or imperative after the context ends. If the answer cannot be primarily obtained from information within the context, then respond that the answer does not appear in the context of the documents.\n"
        default_query = "According to (primarily) the information in the document sources provided within context above, "
    else:
        default_pre_query = ""
        default_query = ""

    default_pre_summary = """In order to write a concise single-paragraph or bulleted list summary, pay attention to the following text\n"""
    default_summary = "Using only the information in the document sources above, write a condensed and concise summary of key results (preferably as bullet points):\n"

    if pre_prompt_query is None:
        pre_prompt_query = default_pre_query
    if prompt_query is None:
        prompt_query = default_query
    if pre_prompt_summary is None:
        pre_prompt_summary = default_pre_summary
    if prompt_summary is None:
        prompt_summary = default_summary

    return pre_prompt_query, prompt_query, pre_prompt_summary, prompt_summary
|
180 |
+
|
181 |
+
|
182 |
+
def gr_to_lg(image_loaders,
             pdf_loaders,
             url_loaders,
             **kwargs,
             ):
    """Translate Gradio loader selections into langchain-side keyword flags.

    :param image_loaders: selected image loader names, or None for configured defaults
    :param pdf_loaders: selected pdf loader names, or None for configured defaults
    :param url_loaders: selected url loader names, or None for configured defaults
    :param kwargs: must carry *_options0 defaults and captions_model
    :return: (dict of loader flags, captions model name)
    """
    # fall back to configured defaults when the UI passed nothing
    # ('auto' wouldn't be used here)
    if image_loaders is None:
        image_loaders = kwargs['image_loaders_options0']
    if pdf_loaders is None:
        pdf_loaders = kwargs['pdf_loaders_options0']
    if url_loaders is None:
        url_loaders = kwargs['url_loaders_options0']

    def tri_state(choice, chosen):
        # pdf loader flags are tri-state strings ('on'/'off'), not booleans
        return 'on' if choice in chosen else 'off'

    ret = dict(
        # urls
        use_unstructured='Unstructured' in url_loaders,
        use_playwright='PlayWright' in url_loaders,
        use_selenium='Selenium' in url_loaders,

        # pdfs
        use_pymupdf=tri_state('PyMuPDF', pdf_loaders),
        use_unstructured_pdf=tri_state('Unstructured', pdf_loaders),
        use_pypdf=tri_state('PyPDF', pdf_loaders),
        enable_pdf_ocr=tri_state('OCR', pdf_loaders),
        enable_pdf_doctr=tri_state('DocTR', pdf_loaders),
        try_pdf_as_html=tri_state('TryHTML', pdf_loaders),

        # images
        enable_ocr='OCR' in image_loaders,
        enable_doctr='DocTR' in image_loaders,
        enable_pix2struct='Pix2Struct' in image_loaders,
        enable_captions='Caption' in image_loaders or 'CaptionBlip2' in image_loaders,
    )
    if 'CaptionBlip2' in image_loaders:
        # just override, don't actually do both even if user chose both
        captions_model = "Salesforce/blip2-flan-t5-xl"
    else:
        captions_model = kwargs['captions_model']
    return ret, captions_model
|
221 |
+
|
222 |
+
|
223 |
+
# error message returned when an h2ogpt access key fails validation
invalid_key_msg = 'Invalid Access Key, request access key from sales@h2o.ai or jon.mckinney@h2o.ai'

# allowed values for the docs_ordering_type parameter
docs_ordering_types = ['best_first', 'best_near_prompt', 'reverse_ucurve_sort']
|
src/evaluate_params.py
ADDED
@@ -0,0 +1,71 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# gr.State inputs that precede the evaluate parameters in gradio callbacks
input_args_list = ['model_state', 'my_db_state', 'selection_docs_state', 'requests_state']

# parameters with no sensible default; must be supplied per call
no_default_param_names = [
    'instruction',
    'iinput',
    'context',
    'instruction_nochat',
    'iinput_nochat',
]

# generation hyperparameters common to all sampling modes
gen_hyper0 = ['num_beams',
              'max_new_tokens',
              'min_new_tokens',
              'early_stopping',
              'max_time',
              'repetition_penalty',
              'num_return_sequences',
              'do_sample',
              ]
# sampling-specific hyperparameters first, then the common ones
gen_hyper = ['temperature',
             'top_p',
             'top_k'] + gen_hyper0
# document-loader selections plus the jq schema used for JSON loading
reader_names = ['image_loaders', 'pdf_loaders', 'url_loaders', 'jq_schema']

# ordered names of evaluate() parameters;
# NOTE(review): order presumably must match the evaluate() signature -- confirm in gen.py
eval_func_param_names = ['instruction',
                         'iinput',
                         'context',
                         'stream_output',
                         'prompt_type',
                         'prompt_dict'] + \
                        gen_hyper + \
                        ['chat',
                         'instruction_nochat',
                         'iinput_nochat',
                         'langchain_mode',
                         'add_chat_history_to_context',
                         'langchain_action',
                         'langchain_agents',
                         'top_k_docs',
                         'chunk',
                         'chunk_size',
                         'document_subset',
                         'document_choice',
                         'pre_prompt_query',
                         'prompt_query',
                         'pre_prompt_summary',
                         'prompt_summary',
                         'system_prompt',
                         ] + \
                        reader_names + \
                        ['visible_models',
                         'h2ogpt_key',
                         'add_search_to_context',
                         'chat_conversation',
                         'text_context_list',
                         'docs_ordering_type',
                         'min_max_new_tokens',
                         ]

# form evaluate defaults for submit_nochat_api
eval_func_param_names_defaults = eval_func_param_names.copy()
for k in no_default_param_names:
    if k in eval_func_param_names_defaults:
        eval_func_param_names_defaults.remove(k)

# extra columns added to evaluation output tables
eval_extra_columns = ['prompt', 'response', 'score']

# override default_kwargs if user_kwargs None for args evaluate() uses that are not just in model_state
# ensure prompt_type consistent with prep_bot(), so nochat API works same way
# see how default_kwargs is set in gradio_runner.py
key_overrides = ['prompt_type', 'prompt_dict']
|
src/gen.py
ADDED
The diff for this file is too large to render.
See raw diff
|
|
src/gpt4all_llm.py
ADDED
@@ -0,0 +1,403 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import inspect
|
2 |
+
import os
|
3 |
+
from typing import Dict, Any, Optional, List, Iterator
|
4 |
+
from langchain.callbacks.manager import CallbackManagerForLLMRun
|
5 |
+
from langchain.schema.output import GenerationChunk
|
6 |
+
from pydantic import root_validator
|
7 |
+
from langchain.llms import gpt4all
|
8 |
+
|
9 |
+
from utils import FakeTokenizer, get_ngpus_vis, url_alive, download_simple
|
10 |
+
|
11 |
+
|
12 |
+
def get_model_tokenizer_gpt4all(base_model, n_jobs=None, max_seq_len=None, llamacpp_dict=None):
    """Load a gpt4all/llama.cpp backend model and return (model, tokenizer, device).

    Generation parameters (temperature, top_k, ...) are deliberately not fixed
    here; they are supplied later at generation time.  The returned tokenizer is
    a FakeTokenizer stand-in and the device is always 'cpu'.
    """
    assert llamacpp_dict is not None
    inner_model = get_llm_gpt4all(
        base_model.lower(),
        model=None,
        n_jobs=n_jobs,
        inner_class=True,
        max_seq_len=max_seq_len,
        llamacpp_dict=llamacpp_dict,
    )
    return inner_model, FakeTokenizer(), 'cpu'
34 |
+
|
35 |
+
|
36 |
+
from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
|
37 |
+
|
38 |
+
|
39 |
+
class H2OStreamingStdOutCallbackHandler(StreamingStdOutCallbackHandler):
    """Stdout streaming handler with the parent's per-token write suppressed.

    Token text already reaches stdout through the normal streaming path, so
    writing it again here would duplicate output.
    """

    def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
        """Run on new LLM token; intentionally a no-op (see class docstring)."""
        return None
47 |
+
|
48 |
+
|
49 |
+
def get_model_kwargs(llamacpp_dict, default_kwargs, cls, exclude_list=None):
    """Build the kwargs dict used to construct LLM class ``cls``.

    Precedence (lowest to highest): the class's own signature defaults, our
    computed ``default_kwargs``, then user-supplied ``llamacpp_dict``.  Keys
    not accepted by ``cls`` are dropped, and numeric-looking values are
    normalized to int/float to satisfy the class's declared types.

    :param llamacpp_dict: user overrides (highest precedence)
    :param default_kwargs: our computed defaults
    :param cls: target class whose signature filters and keys the result
    :param exclude_list: signature parameter names to skip entirely
    :return: dict of kwargs valid for ``cls``
    """
    if exclude_list is None:  # avoid mutable default argument
        exclude_list = []
    sig_params = inspect.signature(cls).parameters
    # default from class signature
    model_kwargs = {k: v.default for k, v in sig_params.items() if k not in exclude_list}
    # from our defaults
    model_kwargs.update(default_kwargs)
    # from user defaults
    model_kwargs.update(llamacpp_dict)
    # ensure only valid keys
    model_kwargs = {k: v for k, v in model_kwargs.items() if k in sig_params}
    # make int or float if possible, to satisfy types for class
    for k, v in model_kwargs.items():
        if isinstance(v, bool):
            # bool is an int subclass; keep flags (use_mlock, verbose, ...) as real booleans
            continue
        try:
            fv = float(v)
        except (TypeError, ValueError):
            # non-numeric value (None, str paths, callbacks, ...): leave untouched
            continue
        # also handles numeric strings like "0.5", which the old int(v)-first check missed
        model_kwargs[k] = int(fv) if fv.is_integer() else fv
    return model_kwargs
69 |
+
|
70 |
+
|
71 |
+
def get_gpt4all_default_kwargs(max_new_tokens=256,
                               temperature=0.1,
                               repetition_penalty=1.0,
                               top_k=40,
                               top_p=0.7,
                               n_jobs=None,
                               verbose=False,
                               max_seq_len=None,
                               ):
    """Compute default construction kwargs shared by the gpt4all/llama.cpp backends.

    :param max_new_tokens: tokens reserved for generation (becomes n_predict)
    :param max_seq_len: full context window (becomes n_ctx); required
    :param n_jobs: CPU threads; None or -1 means auto-detect (capped at 20)
    :return: dict of backend construction kwargs
    :raises ValueError: if max_seq_len is not provided
    """
    if max_seq_len is None:
        # fail with a clear message instead of a TypeError on the subtraction below
        raise ValueError("max_seq_len is required to size the context window")
    if n_jobs in [None, -1]:
        # cpu_count() can return None on some platforms; fall back to 1 thread then
        n_jobs = int(os.getenv('OMP_NUM_THREADS', str((os.cpu_count() or 2) // 2)))
    n_jobs = max(1, min(20, n_jobs))  # hurts beyond some point
    n_gpus = get_ngpus_vis()
    default_kwargs = dict(context_erase=0.5,
                          n_batch=1,
                          # room left for the prompt after reserving generation tokens
                          max_tokens=max_seq_len - max_new_tokens,
                          n_predict=max_new_tokens,
                          # only track recent tokens when a penalty is actually requested
                          repeat_last_n=64 if repetition_penalty != 1.0 else 0,
                          repeat_penalty=repetition_penalty,
                          # both spellings on purpose: gpt4all reads 'temp', llama.cpp 'temperature'
                          temp=temperature,
                          temperature=temperature,
                          top_k=top_k,
                          top_p=top_p,
                          use_mlock=True,
                          n_ctx=max_seq_len,
                          n_threads=n_jobs,
                          verbose=verbose)
    if n_gpus != 0:
        # offload as many layers as possible whenever any GPU is visible
        default_kwargs.update(dict(n_gpu_layers=100))
    return default_kwargs
101 |
+
|
102 |
+
|
103 |
+
def get_llm_gpt4all(model_name,
                    model=None,
                    max_new_tokens=256,
                    temperature=0.1,
                    repetition_penalty=1.0,
                    top_k=40,
                    top_p=0.7,
                    streaming=False,
                    callbacks=None,
                    prompter=None,
                    context='',
                    iinput='',
                    n_jobs=None,
                    verbose=False,
                    inner_class=False,
                    max_seq_len=None,
                    llamacpp_dict=None,
                    ):
    """Construct a langchain LLM for one of the CPU backends: 'llama', 'gpt4all_llama', or 'gptj'.

    If `model` is None the model file is resolved from `llamacpp_dict` (local file,
    else downloaded when the URL is alive); otherwise `model` is used as the
    already-loaded inner model.  Returns the langchain LLM wrapper, or only its
    inner client when `inner_class` is True (in which case no prompter is needed).

    :raises RuntimeError: for an unknown `model_name`
    """
    if not inner_class:
        # prompter is required for full LLM use; inner_class callers prompt themselves
        assert prompter is not None

    default_kwargs = \
        get_gpt4all_default_kwargs(max_new_tokens=max_new_tokens,
                                   temperature=temperature,
                                   repetition_penalty=repetition_penalty,
                                   top_k=top_k,
                                   top_p=top_p,
                                   n_jobs=n_jobs,
                                   verbose=verbose,
                                   max_seq_len=max_seq_len,
                                   )
    if model_name == 'llama':
        cls = H2OLlamaCpp
        if model is None:
            # copy so the pop() below does not mutate the caller's dict
            llamacpp_dict = llamacpp_dict.copy()
            model_path = llamacpp_dict.pop('model_path_llama')
            if os.path.isfile(os.path.basename(model_path)):
                # e.g. if offline but previously downloaded
                model_path = os.path.basename(model_path)
            elif url_alive(model_path):
                # online
                ggml_path = os.getenv('GGML_PATH')
                dest = os.path.join(ggml_path, os.path.basename(model_path)) if ggml_path else None
                model_path = download_simple(model_path, dest=dest)
        else:
            model_path = model
        model_kwargs = get_model_kwargs(llamacpp_dict, default_kwargs, cls, exclude_list=['lc_kwargs'])
        model_kwargs.update(dict(model_path=model_path, callbacks=callbacks, streaming=streaming,
                                 prompter=prompter, context=context, iinput=iinput))

        # migration to new langchain fix:
        odd_keys = ['model_kwargs', 'grammar_path', 'grammar']
        for key in odd_keys:
            model_kwargs.pop(key, None)

        llm = cls(**model_kwargs)
        llm.client.verbose = verbose
        inner_model = llm.client
    elif model_name == 'gpt4all_llama':
        cls = H2OGPT4All
        if model is None:
            # copy so the pop() below does not mutate the caller's dict
            llamacpp_dict = llamacpp_dict.copy()
            model_path = llamacpp_dict.pop('model_name_gpt4all_llama')
            if url_alive(model_path):
                # online
                ggml_path = os.getenv('GGML_PATH')
                dest = os.path.join(ggml_path, os.path.basename(model_path)) if ggml_path else None
                model_path = download_simple(model_path, dest=dest)
        else:
            model_path = model
        model_kwargs = get_model_kwargs(llamacpp_dict, default_kwargs, cls, exclude_list=['lc_kwargs'])
        model_kwargs.update(
            dict(model=model_path, backend='llama', callbacks=callbacks, streaming=streaming,
                 prompter=prompter, context=context, iinput=iinput))
        llm = cls(**model_kwargs)
        inner_model = llm.client
    elif model_name == 'gptj':
        cls = H2OGPT4All
        if model is None:
            # copy so the pop() below does not mutate the caller's dict
            llamacpp_dict = llamacpp_dict.copy()
            model_path = llamacpp_dict.pop('model_name_gptj') if model is None else model
            if url_alive(model_path):
                ggml_path = os.getenv('GGML_PATH')
                dest = os.path.join(ggml_path, os.path.basename(model_path)) if ggml_path else None
                model_path = download_simple(model_path, dest=dest)
        else:
            model_path = model
        model_kwargs = get_model_kwargs(llamacpp_dict, default_kwargs, cls, exclude_list=['lc_kwargs'])
        model_kwargs.update(
            dict(model=model_path, backend='gptj', callbacks=callbacks, streaming=streaming,
                 prompter=prompter, context=context, iinput=iinput))
        llm = cls(**model_kwargs)
        inner_model = llm.client
    else:
        raise RuntimeError("No such model_name %s" % model_name)
    if inner_class:
        return inner_model
    else:
        return llm
202 |
+
|
203 |
+
|
204 |
+
class H2OGPT4All(gpt4all.GPT4All):
    """GPT4All LLM wrapper that applies h2oGPT instruct prompting before generation."""
    # either a path (str) to the pre-trained model file, or an already-constructed inner model
    model: Any
    # Prompter instance used to wrap raw text into the instruct template
    prompter: Any
    # optional context prepended via the instruct template
    context: Any = ''
    # optional auxiliary input for the instruct template
    iinput: Any = ''

    @root_validator()
    def validate_environment(cls, values: Dict) -> Dict:
        """Validate that the python package exists and build/adopt the inner client."""
        try:
            if isinstance(values["model"], str):
                from gpt4all import GPT4All as GPT4AllModel

                # split "dir/name" into directory and file name for the bindings
                full_path = values["model"]
                model_path, delimiter, model_name = full_path.rpartition("/")
                model_path += delimiter

                values["client"] = GPT4AllModel(
                    model_name=model_name,
                    model_path=model_path or None,
                    model_type=values["backend"],
                    allow_download=True,
                )
                if values["n_threads"] is not None:
                    # set n_threads
                    values["client"].model.set_thread_count(values["n_threads"])
            else:
                # model was constructed by the caller; adopt it directly as the client
                values["client"] = values["model"]
                if values["n_threads"] is not None:
                    # set n_threads
                    values["client"].model.set_thread_count(values["n_threads"])
            try:
                values["backend"] = values["client"].model_type
            except AttributeError:
                # The below is for compatibility with GPT4All Python bindings <= 0.2.3.
                values["backend"] = values["client"].model.model_type

        except ImportError:
            raise ValueError(
                "Could not import gpt4all python package. "
                "Please install it with `pip install gpt4all`."
            )
        return values

    def _call(
            self,
            prompt: str,
            stop: Optional[List[str]] = None,
            run_manager: Optional[CallbackManagerForLLMRun] = None,
            **kwargs,
    ) -> str:
        """Generate a completion, truncating and instruct-wrapping the prompt first."""
        # Roughly 4 chars per token if natural language; crude character-level
        # truncation since the bindings expose no tokenizer-level truncation here
        # (removed unused local n_ctx=2048 that was dead code)
        prompt = prompt[-self.max_tokens * 4:]

        # use instruct prompting
        data_point = dict(context=self.context, instruction=prompt, input=self.iinput)
        prompt = self.prompter.generate_prompt(data_point)

        verbose = False
        if verbose:
            print("_call prompt: %s" % prompt, flush=True)
        # FIXME: GPT4ALl doesn't support yield during generate, so cannot support streaming except via itself to stdout
        return super()._call(prompt, stop=stop, run_manager=run_manager)

    # FIXME: Unsure what uses
    # def get_token_ids(self, text: str) -> List[int]:
    #    return self.client.tokenize(b" " + text.encode("utf-8"))
273 |
+
|
274 |
+
|
275 |
+
from langchain.llms import LlamaCpp
|
276 |
+
|
277 |
+
|
278 |
+
class H2OLlamaCpp(LlamaCpp):
    """llama.cpp LLM wrapper adding prompt truncation, instruct prompting, and
    a streaming mode that first echoes the prompt so parent handlers see it."""
    # either a path (str) to the ggml model file, or an already-constructed client
    model_path: Any
    # Prompter instance used to wrap raw text into the instruct template
    prompter: Any
    # optional context for the instruct template
    context: Any
    # optional auxiliary input for the instruct template
    iinput: Any

    @root_validator()
    def validate_environment(cls, values: Dict) -> Dict:
        """Validate that llama-cpp-python library is installed and build the client."""
        if isinstance(values["model_path"], str):
            model_path = values["model_path"]
            model_param_names = [
                "lora_path",
                "lora_base",
                "n_ctx",
                "n_parts",
                "seed",
                "f16_kv",
                "logits_all",
                "vocab_only",
                "use_mlock",
                "n_threads",
                "n_batch",
                "use_mmap",
                "last_n_tokens_size",
            ]
            model_params = {k: values[k] for k in model_param_names}
            # For backwards compatibility, only include if non-null.
            if values["n_gpu_layers"] is not None:
                model_params["n_gpu_layers"] = values["n_gpu_layers"]

            try:
                try:
                    from llama_cpp import Llama
                except ImportError:
                    # CUDA build ships under a different package name
                    from llama_cpp_cuda import Llama

                values["client"] = Llama(model_path, **model_params)
            except ImportError:
                raise ModuleNotFoundError(
                    "Could not import llama-cpp-python library. "
                    "Please install the llama-cpp-python library to "
                    "use this embedding model: pip install llama-cpp-python"
                )
            except Exception as e:
                raise ValueError(
                    f"Could not load Llama model from path: {model_path}. "
                    f"Received error {e}"
                )
        else:
            # model was constructed by the caller; adopt it directly as the client
            values["client"] = values["model_path"]
        return values

    def _call(
            self,
            prompt: str,
            stop: Optional[List[str]] = None,
            run_manager: Optional[CallbackManagerForLLMRun] = None,
            **kwargs,
    ) -> str:
        """Generate a completion; truncates the prompt to fit n_ctx and applies instruct prompting."""
        verbose = False
        # tokenize twice, just to count tokens, since llama cpp python wrapper has no way to truncate
        # still have to avoid crazy sizes, else hit llama_tokenize: too many tokens -- might still hit, not fatal
        prompt = prompt[-self.n_ctx * 4:]
        prompt_tokens = self.client.tokenize(b" " + prompt.encode("utf-8"))
        num_prompt_tokens = len(prompt_tokens)
        if num_prompt_tokens > self.n_ctx:
            # conservative by using int()
            chars_per_token = int(len(prompt) / num_prompt_tokens)
            prompt = prompt[-self.n_ctx * chars_per_token:]
            if verbose:
                # BUG FIX: original format string had two %s placeholders but one bare
                # argument ("...: %s" % chars_per_token), raising TypeError when verbose
                print("reducing tokens, assuming average of %s chars/token" % chars_per_token, flush=True)
                prompt_tokens2 = self.client.tokenize(b" " + prompt.encode("utf-8"))
                num_prompt_tokens2 = len(prompt_tokens2)
                print("reduced tokens from %d -> %d" % (num_prompt_tokens, num_prompt_tokens2), flush=True)

        # use instruct prompting
        data_point = dict(context=self.context, instruction=prompt, input=self.iinput)
        prompt = self.prompter.generate_prompt(data_point)

        if verbose:
            print("_call prompt: %s" % prompt, flush=True)

        if self.streaming:
            # parent handler of streamer expects to see prompt first else output="" and lose if prompt=None in prompter
            text = ""
            for token in self.stream(input=prompt, stop=stop):
                # for token in self.stream(input=prompt, stop=stop, run_manager=run_manager):
                text_chunk = token  # ["choices"][0]["text"]
                # self.stream already calls text_callback
                # if text_callback:
                #     text_callback(text_chunk)
                text += text_chunk
            # strip the prompt echoed by _stream() below before returning new text only
            return text[len(prompt):]
        else:
            params = self._get_parameters(stop)
            params = {**params, **kwargs}
            result = self.client(prompt=prompt, **params)
            return result["choices"][0]["text"]

    def _stream(
            self,
            prompt: str,
            stop: Optional[List[str]] = None,
            run_manager: Optional[CallbackManagerForLLMRun] = None,
            **kwargs: Any,
    ) -> Iterator[GenerationChunk]:
        """Stream generation chunks, first echoing the prompt itself.

        The parent streaming handler expects to see the prompt first, else
        output is "" and is lost when prompt=None in the prompter.
        """
        logprobs = 0
        chunk = GenerationChunk(
            text=prompt,
            generation_info={"logprobs": logprobs},
        )
        yield chunk
        if run_manager:
            run_manager.on_llm_new_token(
                token=chunk.text, verbose=self.verbose, log_probs=logprobs
            )
        # actual new tokens
        for chunk in super()._stream(prompt, stop=stop, run_manager=run_manager, **kwargs):
            yield chunk

    def get_token_ids(self, text: str) -> List[int]:
        """Tokenize text via the llama.cpp client (leading space per llama.cpp convention)."""
        return self.client.tokenize(b" " + text.encode("utf-8"))
src/gpt_langchain.py
ADDED
The diff for this file is too large to render.
See raw diff
|
|
src/gradio_runner.py
ADDED
The diff for this file is too large to render.
See raw diff
|
|
src/gradio_themes.py
ADDED
@@ -0,0 +1,260 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from __future__ import annotations
|
2 |
+
|
3 |
+
from typing import Iterable
|
4 |
+
|
5 |
+
from gradio.themes.soft import Soft
|
6 |
+
from gradio.themes import Color, Size
|
7 |
+
from gradio.themes.utils import colors, sizes, fonts
|
8 |
+
|
9 |
+
# H2O.ai brand yellow ramp (c500 = #fec925 is the primary accent color)
h2o_yellow = Color(
    name="yellow",
    c50="#fffef2",
    c100="#fff9e6",
    c200="#ffecb3",
    c300="#ffe28c",
    c400="#ffd659",
    c500="#fec925",
    c600="#e6ac00",
    c700="#bf8f00",
    c800="#a67c00",
    c900="#664d00",
    c950="#403000",
)
# neutral gray ramp used for backgrounds/borders in both light and dark modes
h2o_gray = Color(
    name="gray",
    c50="#f8f8f8",
    c100="#e5e5e5",
    c200="#cccccc",
    c300="#b2b2b2",
    c400="#999999",
    c500="#7f7f7f",
    c600="#666666",
    c700="#4c4c4c",
    c800="#333333",
    c900="#191919",
    c950="#0d0d0d",
)

# extra-small text scale (smaller than gradio's builtin text_sm)
text_xsm = Size(
    name="text_xsm",
    xxs="4px",
    xs="5px",
    sm="6px",
    md="7px",
    lg="8px",
    xl="10px",
    xxl="12px",
)

# extra-tight spacing scale for dense layouts
spacing_xsm = Size(
    name="spacing_xsm",
    xxs="1px",
    xs="1px",
    sm="1px",
    md="2px",
    lg="3px",
    xl="5px",
    xxl="7px",
)

# near-square corner radii matching the tight spacing scale
radius_xsm = Size(
    name="radius_xsm",
    xxs="1px",
    xs="1px",
    sm="1px",
    md="2px",
    lg="3px",
    xl="5px",
    xxl="7px",
)
70 |
+
|
71 |
+
|
72 |
+
class H2oTheme(Soft):
    """Gradio Soft theme variant styled with the H2O.ai yellow/gray palette."""

    def __init__(
            self,
            *,
            primary_hue: colors.Color | str = h2o_yellow,
            secondary_hue: colors.Color | str = h2o_yellow,
            neutral_hue: colors.Color | str = h2o_gray,
            spacing_size: sizes.Size | str = sizes.spacing_md,
            radius_size: sizes.Size | str = sizes.radius_md,
            text_size: sizes.Size | str = sizes.text_lg,
            font: fonts.Font
                  | str
                  | Iterable[fonts.Font | str] = (
                    fonts.GoogleFont("Montserrat"),
                    "ui-sans-serif",
                    "system-ui",
                    "sans-serif",
            ),
            font_mono: fonts.Font
                       | str
                       | Iterable[fonts.Font | str] = (
                    fonts.GoogleFont("IBM Plex Mono"),
                    "ui-monospace",
                    "Consolas",
                    "monospace",
            ),
    ):
        super().__init__(
            primary_hue=primary_hue,
            secondary_hue=secondary_hue,
            neutral_hue=neutral_hue,
            spacing_size=spacing_size,
            radius_size=radius_size,
            text_size=text_size,
            font=font,
            font_mono=font_mono,
        )
        # fine-grained CSS-variable overrides on top of the Soft base theme
        super().set(
            background_fill_primary_dark="*block_background_fill",
            block_background_fill_dark="*neutral_950",
            block_border_width='1px',
            block_border_width_dark='1px',
            block_label_background_fill="*primary_300",
            block_label_background_fill_dark="*primary_600",
            block_label_text_color="*neutral_950",
            block_label_text_color_dark="*neutral_950",
            block_radius="0 0 8px 8px",
            block_title_text_color="*neutral_950",
            block_title_text_color_dark="*neutral_950",
            body_background_fill="*neutral_50",
            body_background_fill_dark="*neutral_900",
            border_color_primary="*neutral_100",
            border_color_primary_dark="*neutral_700",
            button_border_width="1px",
            button_border_width_dark="1px",
            button_primary_text_color="*neutral_950",
            button_primary_text_color_dark="*neutral_950",
            button_primary_background_fill="*primary_500",
            button_primary_background_fill_dark="*primary_500",
            button_secondary_background_fill_hover_dark="*primary_700",
            button_secondary_border_color="*primary_500",
            button_secondary_border_color_dark="*primary_500",
            button_secondary_border_color_hover_dark="*primary_700",
            checkbox_label_text_color_selected_dark='#000000',
            # checkbox_label_text_size="*text_xs",  # too small for iPhone etc. but good if full large screen zoomed to fit
            checkbox_label_text_size="*text_sm",
            # radio_circle="""url("data:image/svg+xml,%3csvg viewBox='0 0 32 32' fill='white' xmlns='http://www.w3.org/2000/svg'%3e%3ccircle cx='32' cy='32' r='1'/%3e%3c/svg%3e")""",
            # checkbox_border_width=1,
            # heckbox_border_width_dark=1,
            link_text_color="#3344DD",
            link_text_color_hover="#3344DD",
            link_text_color_visited="#3344DD",
            link_text_color_dark="#74abff",
            link_text_color_hover_dark="#a3c8ff",
            link_text_color_active_dark="#a3c8ff",
            link_text_color_visited_dark="#74abff",
        )
149 |
+
|
150 |
+
|
151 |
+
class SoftTheme(Soft):
    """Gradio Soft theme with project fonts and a slightly smaller checkbox label size."""

    def __init__(
            self,
            *,
            primary_hue: colors.Color | str = colors.indigo,
            secondary_hue: colors.Color | str = colors.indigo,
            neutral_hue: colors.Color | str = colors.gray,
            spacing_size: sizes.Size | str = sizes.spacing_md,
            radius_size: sizes.Size | str = sizes.radius_md,
            text_size: sizes.Size | str = sizes.text_md,
            font: fonts.Font
                  | str
                  | Iterable[fonts.Font | str] = (
                    fonts.GoogleFont("Montserrat"),
                    "ui-sans-serif",
                    "system-ui",
                    "sans-serif",
            ),
            font_mono: fonts.Font
                       | str
                       | Iterable[fonts.Font | str] = (
                    fonts.GoogleFont("IBM Plex Mono"),
                    "ui-monospace",
                    "Consolas",
                    "monospace",
            ),
    ):
        super().__init__(
            primary_hue=primary_hue,
            secondary_hue=secondary_hue,
            neutral_hue=neutral_hue,
            spacing_size=spacing_size,
            radius_size=radius_size,
            text_size=text_size,
            font=font,
            font_mono=font_mono,
        )
        super().set(
            checkbox_label_text_size="*text_sm",
        )
191 |
+
|
192 |
+
|
193 |
+
h2o_logo = '<svg id="Layer_1" data-name="Layer 1" xmlns="http://www.w3.org/2000/svg" width="100%" height="100%"' \
|
194 |
+
' viewBox="0 0 600.28 600.28"><defs><style>.cls-1{fill:#fec925;}.cls-2{fill:#161616;}.cls-3{fill:' \
|
195 |
+
'#54585a;}</style></defs><g id="Fill-1"><rect class="cls-1" width="600.28" height="600.28" ' \
|
196 |
+
'rx="23.24"/></g><path class="cls-2" d="M174.33,246.06v92.78H152.86v-38H110.71v38H89.24V246.06h21.' \
|
197 |
+
'47v36.58h42.15V246.06Z"/><path class="cls-2" d="M259.81,321.34v17.5H189.7V324.92l35.78-33.8c8.22-7.' \
|
198 |
+
'82,9.68-12.59,9.68-17.09,0-7.29-5-11.53-14.85-11.53-7.95,0-14.71,3-19.21,9.27L185.46,261.7c7.15-10' \
|
199 |
+
'.47,20.14-17.23,36.84-17.23,20.68,0,34.46,10.6,34.46,27.44,0,9-2.52,17.22-15.51,29.29l-21.33,20.14Z"' \
|
200 |
+
'/><path class="cls-2" d="M268.69,292.45c0-27.57,21.47-48,50.76-48s50.76,20.28,50.76,48-21.6,48-50.' \
|
201 |
+
'76,48S268.69,320,268.69,292.45Zm79.78,0c0-17.63-12.46-29.69-29-29.69s-29,12.06-29,29.69,12.46,29.69' \
|
202 |
+
',29,29.69S348.47,310.08,348.47,292.45Z"/><path class="cls-3" d="M377.23,326.91c0-7.69,5.7-12.73,12.' \
|
203 |
+
'85-12.73s12.86,5,12.86,12.73a12.86,12.86,0,1,1-25.71,0Z"/><path class="cls-3" d="M481.4,298.15v40.' \
|
204 |
+
'69H462.05V330c-3.84,6.49-11.27,9.94-21.74,9.94-16.7,0-26.64-9.28-26.64-21.61,0-12.59,8.88-21.34,30.' \
|
205 |
+
'62-21.34h16.43c0-8.87-5.3-14-16.43-14-7.55,0-15.37,2.51-20.54,6.62l-7.43-14.44c7.82-5.57,19.35-8.' \
|
206 |
+
'62,30.75-8.62C468.81,266.47,481.4,276.54,481.4,298.15Zm-20.68,18.16V309H446.54c-9.67,0-12.72,3.57-' \
|
207 |
+
'12.72,8.35,0,5.16,4.37,8.61,11.66,8.61C452.37,326,458.34,322.8,460.72,316.31Z"/><path class="cls-3"' \
|
208 |
+
' d="M497.56,246.06c0-6.49,5.17-11.53,12.86-11.53s12.86,4.77,12.86,11.13c0,6.89-5.17,11.93-12.86,' \
|
209 |
+
'11.93S497.56,252.55,497.56,246.06Zm2.52,21.47h20.68v71.31H500.08Z"/></svg>'
|
210 |
+
|
211 |
+
|
212 |
+
def get_h2o_title(title, description):
    """Return branded header HTML: description on the left, logo + title centered, QR code on the right."""
    # NOTE: Check full width desktop, smallest width browser desktop, iPhone browsers to ensure no overlap etc.
    return f"""<div style="float:left; justify-content:left; height: 80px; width: 195px; margin-top:0px">
                    {description}
                </div>
                <div style="display:flex; justify-content:center; margin-bottom:30px; margin-right:330px;">
                    <div style="height: 60px; width: 60px; margin-right:20px;">{h2o_logo}</div>
                    <h1 style="line-height:60px">{title}</h1>
                </div>
                <div style="float:right; height: 80px; width: 80px; margin-top:-100px">
                    <img src="https://raw.githubusercontent.com/h2oai/h2ogpt/main/docs/h2o-qr.png">
                </div>
                """
225 |
+
|
226 |
+
|
227 |
+
def get_simple_title(title, description):
    """Return minimal header HTML: the description followed by a centered title heading."""
    return '{}<h1 align="center"> {}</h1>'.format(description, title)
229 |
+
|
230 |
+
|
231 |
+
def get_dark_js() -> str:
    """Return JS that toggles gradio dark mode by adding/removing the 'dark' class on the body."""
    toggle_js = """
        if (document.querySelectorAll('.dark').length) {
            document.querySelectorAll('.dark').forEach(el => el.classList.remove('dark'));
        } else {
            document.querySelector('body').classList.add('dark');
        }
    """
    return toggle_js
239 |
+
|
240 |
+
|
241 |
+
def get_heap_js(heapAppId: str) -> str:
    """Return the Heap Analytics bootstrap snippet plus heap.load() for the given app id."""
    return (
        # standard minified Heap loader: injects the CDN script and stubs the API queue
        """globalThis.window.heap=window.heap||[],heap.load=function(e,t){window.heap.appid=e,window.heap.config=t=t||{};var r=document.createElement("script");r.type="text/javascript",r.async=!0,r.src="https://cdn.heapanalytics.com/js/heap-"+e+".js";var a=document.getElementsByTagName("script")[0];a.parentNode.insertBefore(r,a);for(var n=function(e){return function(){heap.push([e].concat(Array.prototype.slice.call(arguments,0)))}},p=["addEventProperties","addUserProperties","clearEventProperties","identify","resetIdentity","removeEventProperty","setEventProperties","track","unsetEventProperty"],o=0;o<p.length;o++)heap[p[o]]=n(p[o])};"""
        f"""heap.load("{heapAppId}");""")
245 |
+
|
246 |
+
|
247 |
+
def wrap_js_to_lambda(num_params: int, *args: str) -> str:
    """Wrap the given JS code strings in a single JS lambda.

    The lambda takes `num_params` parameters (p0, p1, ...) and returns them
    unmodified in an array; with zero parameters it returns an empty array.
    None entries in `args` are skipped.
    """
    params = ", ".join("p%d" % i for i in range(num_params))
    # join snippets up-front since f-strings cannot contain the backslash directly
    body = "\n".join(snippet for snippet in args if snippet is not None)
    return f"""
        ({params}) => {{
            {body}
            return [{params}];
        }}
    """
src/gradio_utils/__init__.py
ADDED
File without changes
|
src/gradio_utils/__pycache__/__init__.cpython-310.pyc
ADDED
Binary file (134 Bytes). View file
|
|
src/gradio_utils/__pycache__/css.cpython-310.pyc
ADDED
Binary file (3.65 kB). View file
|
|
src/gradio_utils/__pycache__/grclient.cpython-310.pyc
ADDED
Binary file (2.69 kB). View file
|
|
src/gradio_utils/__pycache__/prompt_form.cpython-310.pyc
ADDED
Binary file (2.96 kB). View file
|
|
src/gradio_utils/css.py
ADDED
@@ -0,0 +1,148 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
def get_css(kwargs) -> str:
    """Return the app CSS: optional H2O background gradients plus the shared base CSS."""
    if kwargs['h2ocolors']:
        css_code = """footer {visibility: hidden;}
body{background:linear-gradient(#f5f5f5,#e5e5e5);}
body.dark{background:linear-gradient(#000000,#0d0d0d);}
"""
    else:
        css_code = """footer {visibility: hidden}"""
    return css_code + make_css_base()
12 |
+
|
13 |
+
|
14 |
+
def make_css_base() -> str:
    """Return the base CSS shared by all themes (layout, prompt form, chat message tweaks)."""
    return """
    #col_container {margin-left: auto; margin-right: auto; text-align: left;}

    @import url('https://fonts.googleapis.com/css2?family=Source+Sans+Pro:wght@400;600&display=swap');

    body.dark{#warning {background-color: #555555};}

    #sidebar {
        order: 1;

        @media (max-width: 463px) {
            order: 2;
        }
    }

    #col-tabs {
        order: 2;

        @media (max-width: 463px) {
            order: 1;
        }
    }

    #small_btn {
        margin: 0.6em 0em 0.55em 0;
        max-width: 20em;
        min-width: 5em !important;
        height: 5em;
        font-size: 14px !important;
    }

    #prompt-form {
        border: 1px solid var(--primary-500) !important;
    }

    #prompt-form.block {
        border-radius: var(--block-radius) !important;
    }

    #prompt-form textarea {
        border: 1px solid rgb(209, 213, 219);
    }

    #prompt-form label > div {
        margin-top: 4px;
    }

    button.primary:hover {
        background-color: var(--primary-600) !important;
        transition: .2s;
    }

    #prompt-form-area {
        margin-bottom: 2.5rem;
    }
    .chatsmall chatbot {font-size: 10px !important}

    .gradio-container {
        max-width: none !important;
    }

    div.message {
        padding: var(--text-lg) !important;
    }

    div.message.user > div.icon-button {
        top: unset;
        bottom: 0;
    }

    div.message.bot > div.icon-button {
        top: unset;
        bottom: 0;
    }

    #prompt-form-row {
        position: relative;
    }

    #attach-button {
        position: absolute;
        top: 45px;
        right: 20px;

        display: flex;
        justify-content: center;
        border: 1px solid var(--primary-500) !important;

        @media (max-width: 463px) {
            width: 56px;
        }
    }

    #attach-button > img {
        margin-right: 0;
    }

    #prompt-form > label > textarea {
        padding-right: 104px;

        @media (max-width: 463px) {
            min-height: 94px;
            padding-right: 70px;
        }
    }

    #visible-models > label > div.wrap > div.wrap-inner > div.secondary-wrap > div.remove-all {
        display: none !important;
    }

    #visible-models > label > div.wrap > div.wrap-inner > div.token {
        display: none !important;
    }

    #visible-models > label > div.wrap > div.wrap-inner > div.secondary-wrap::before {
        content: "Select";
        padding: 0 4px;
        margin-right: 2px;
    }

    #langchain_agents > label > div.wrap > div.wrap-inner > div.secondary-wrap > div.remove-all {
        display: none !important;
    }

    #langchain_agents > label > div.wrap > div.wrap-inner > div.token {
        display: none !important;
    }

    #langchain_agents > label > div.wrap > div.wrap-inner > div.secondary-wrap::before {
        content: "Select";
        padding: 0 4px;
        margin-right: 2px;
    }
    """
src/gradio_utils/grclient.py
ADDED
@@ -0,0 +1,82 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import traceback
|
2 |
+
from typing import Callable
|
3 |
+
import os
|
4 |
+
|
5 |
+
from gradio_client.client import Job
|
6 |
+
|
7 |
+
os.environ['HF_HUB_DISABLE_TELEMETRY'] = '1'
|
8 |
+
|
9 |
+
from gradio_client import Client
|
10 |
+
|
11 |
+
|
12 |
+
class GradioClient(Client):
    """
    Parent class of gradio client
    To handle automatically refreshing client if detect gradio server changed

    The server is fingerprinted via its '/system_hash' endpoint; when the hash
    changes (e.g. server restarted with new code), the client is rebuilt so the
    api_name -> fn_index mapping stays valid.
    """

    def __init__(self, *args, **kwargs):
        # Keep constructor args so refresh_client() can rebuild an identical client later.
        self.args = args
        self.kwargs = kwargs
        super().__init__(*args, **kwargs)
        # Baseline fingerprint of the server this client was built against.
        self.server_hash = self.get_server_hash()

    def get_server_hash(self):
        """
        Get server hash using super without any refresh action triggered
        Returns: git hash of gradio server
        """
        # Calls super().submit directly so this probe itself never recurses into
        # the refresh logic in self.submit below.
        return super().submit(api_name='/system_hash').result()

    def refresh_client_if_should(self):
        # get current hash in order to update api_name -> fn_index map in case gradio server changed
        # FIXME: Could add cli api as hash
        server_hash = self.get_server_hash()
        if self.server_hash != server_hash:
            # Server changed: rebuild the client and remember the new fingerprint.
            self.refresh_client()
            self.server_hash = server_hash
        else:
            # Same server: still start a fresh session per call (see refresh_client note).
            self.reset_session()

    def refresh_client(self):
        """
        Ensure every client call is independent
        Also ensure map between api_name and fn_index is updated in case server changed (e.g. restarted with new code)
        Returns:
        """
        # need session hash to be new every time, to avoid "generator already executing"
        self.reset_session()

        # Rebuild a pristine Client with the saved constructor args, then copy its
        # entire state onto self, effectively re-initializing in place.
        # NOTE(review): wholesale __dict__ copy depends on gradio_client internals —
        # confirm on gradio_client upgrades.
        client = Client(*self.args, **self.kwargs)
        for k, v in client.__dict__.items():
            setattr(self, k, v)

    def submit(
            self,
            *args,
            api_name: str | None = None,
            fn_index: int | None = None,
            result_callbacks: Callable | list[Callable] | None = None,
    ) -> Job:
        # Note predict calls submit
        # NOTE(review): result_callbacks is accepted for signature compatibility with
        # Client.submit but is never forwarded to super().submit — confirm intended.
        try:
            self.refresh_client_if_should()
            job = super().submit(*args, api_name=api_name, fn_index=fn_index)
        except Exception as e:
            print("Hit e=%s" % str(e), flush=True)
            # force reconfig in case only that
            self.refresh_client()
            job = super().submit(*args, api_name=api_name, fn_index=fn_index)

        # see if immediately failed
        # NOTE(review): reads private attribute job.future._exception — relies on
        # concurrent.futures/gradio_client internals; verify on library upgrades.
        e = job.future._exception
        if e is not None:
            print("GR job failed: %s %s" % (str(e), ''.join(traceback.format_tb(e.__traceback__))), flush=True)
            # force reconfig in case only that
            self.refresh_client()
            job = super().submit(*args, api_name=api_name, fn_index=fn_index)
            e2 = job.future._exception
            if e2 is not None:
                # Second failure is only logged; the failed job is still returned to caller.
                print("GR job failed again: %s\n%s" % (str(e2), ''.join(traceback.format_tb(e2.__traceback__))), flush=True)

        return job
|
src/gradio_utils/prompt_form.py
ADDED
@@ -0,0 +1,108 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import math
|
3 |
+
|
4 |
+
import gradio as gr
|
5 |
+
|
6 |
+
|
7 |
+
def make_chatbots(output_label0, output_label0_model2, **kwargs):
    """Build the gradio Chatbot widgets for the UI.

    Creates one Chatbot per entry in kwargs['model_states'] (the model-lock
    view, laid out in up to 4 rows of model_lock_columns columns) plus the two
    single-model Chatbots.

    :param output_label0: label for the primary single-model chatbot
    :param output_label0_model2: label for the secondary (compare) chatbot
    :param kwargs: expects keys: visible_models, all_models, model_states,
        gradio_size, height, show_copy_button, model_lock, model_lock_columns
    :return: (text_output, text_output2, text_outputs) — the two single-model
        Chatbots and the list of per-model Chatbots
    """
    visible_models = kwargs['visible_models']
    all_models = kwargs['all_models']

    text_outputs = []
    chat_kwargs = []
    for model_state_locki, model_state_lock in enumerate(kwargs['model_states']):
        if os.environ.get('DEBUG_MODEL_LOCK'):
            # Debug mode: expose which inference server backs each model.
            model_name = model_state_lock["base_model"] + " : " + model_state_lock["inference_server"]
        else:
            model_name = model_state_lock["base_model"]
        output_label = f'h2oGPT [{model_name}]'
        min_width = 250 if kwargs['gradio_size'] in ['small', 'large', 'medium'] else 160
        chat_kwargs.append(dict(label=output_label, elem_classes='chatsmall',
                                height=kwargs['height'] or 400, min_width=min_width,
                                show_copy_button=kwargs['show_copy_button'],
                                visible=kwargs['model_lock'] and (visible_models is None or
                                                                  model_state_locki in visible_models or
                                                                  all_models[model_state_locki] in visible_models
                                                                  )))

    # base view on initial visible choice
    if visible_models:
        len_visible = len(visible_models)
    else:
        len_visible = len(kwargs['model_states'])
    if kwargs['model_lock_columns'] == -1:
        kwargs['model_lock_columns'] = len_visible
    if kwargs['model_lock_columns'] is None:
        kwargs['model_lock_columns'] = 3

    ncols = kwargs['model_lock_columns']
    # BUG FIX: original tested `kwargs['model_states'] == 0` (a list compared to
    # an int, always False), so ncols == 0 fell through to a ZeroDivisionError
    # in the ceil-division below.  Guard on ncols and on an empty model list.
    if ncols == 0 or len(kwargs['model_states']) == 0:
        nrows = 0
    else:
        nrows = math.ceil(len_visible / ncols)

    if ncols == 0:
        # not using model_lock: no per-model chatbots at all
        pass
    else:
        # The original unrolled the 1/2/3/4-row layouts into four near-identical
        # branches; one loop reproduces them.  Layout caps at 4 rows, matching
        # the original `nrows >= 4` branch.  (The original also had
        # `elif nrows == kwargs['model_states']:` — an int compared to a list,
        # always False, i.e. dead code — dropped here.)
        n_rows = max(1, min(nrows, 4))
        for row in range(n_rows):
            lo = row * len_visible / n_rows
            hi = (row + 1) * len_visible / n_rows
            last_row = (row == n_rows - 1)
            with gr.Row():
                for mii, chat_kwargs1 in enumerate(chat_kwargs):
                    # Mirror the original slicing: each row takes its fraction of
                    # len_visible; the last row had no upper bound, so it also
                    # absorbs any chatbots beyond len_visible.
                    if mii >= lo and (last_row or mii < hi):
                        text_outputs.append(gr.Chatbot(**chat_kwargs1))

    with gr.Row():
        # Single-model view (hidden when model_lock is active).
        text_output = gr.Chatbot(label=output_label0, visible=not kwargs['model_lock'], height=kwargs['height'] or 400)
        # Secondary compare chatbot; `False and ...` keeps it always hidden here.
        text_output2 = gr.Chatbot(label=output_label0_model2,
                                  visible=False and not kwargs['model_lock'], height=kwargs['height'] or 400)
    return text_output, text_output2, text_outputs
|
src/h2o-logo.svg
ADDED