Spaces:
Runtime error
Runtime error
Upload 41 files
Browse files- .gitattributes +6 -0
- LICENSE +201 -0
- app.py +115 -0
- assets/attention_all_you_need.pdf +0 -0
- assets/attention_short.pdf +0 -0
- assets/doc_gif.gif +3 -0
- assets/dog_monalisa.jpeg +0 -0
- assets/image_demo.gif +3 -0
- assets/image_doc.gif +3 -0
- assets/image_doc_rag.gif +3 -0
- assets/rag_gif.gif +3 -0
- assets/text_completion_gif.gif +3 -0
- assets/upload_chat.json +10 -0
- assets/upload_few_shot.json +10 -0
- llama_cpp_requirements.txt +1 -0
- mlx_requirements.txt +2 -0
- multipurpose_chatbot/.DS_Store +0 -0
- multipurpose_chatbot/__init__.py +0 -0
- multipurpose_chatbot/configs.py +110 -0
- multipurpose_chatbot/demos/.DS_Store +0 -0
- multipurpose_chatbot/demos/__init__.py +8 -0
- multipurpose_chatbot/demos/base_demo.py +105 -0
- multipurpose_chatbot/demos/batch_inference.py +246 -0
- multipurpose_chatbot/demos/chat_interface.py +704 -0
- multipurpose_chatbot/demos/multimodal_chat_interface.py +1293 -0
- multipurpose_chatbot/demos/rag_chat_interface.py +642 -0
- multipurpose_chatbot/demos/text_completion.py +199 -0
- multipurpose_chatbot/engines/.DS_Store +0 -0
- multipurpose_chatbot/engines/__init__.py +54 -0
- multipurpose_chatbot/engines/base_engine.py +46 -0
- multipurpose_chatbot/engines/debug_engine.py +49 -0
- multipurpose_chatbot/engines/llama_cpp_engine.py +131 -0
- multipurpose_chatbot/engines/llava15_transformers_engine.py +230 -0
- multipurpose_chatbot/engines/llava_llama_cpp_engine.py +280 -0
- multipurpose_chatbot/engines/mlx_engine.py +202 -0
- multipurpose_chatbot/engines/transformers_engine.py +452 -0
- multipurpose_chatbot/engines/vllm_engine.py +233 -0
- multipurpose_chatbot/globals.py +33 -0
- pyproject.toml +0 -0
- requirements.txt +11 -0
- transformers_requirements.txt +1 -0
- vllm_requirements.txt +2 -0
.gitattributes
CHANGED
@@ -33,3 +33,9 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
|
|
|
|
|
|
|
|
33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
36 |
+
assets/doc_gif.gif filter=lfs diff=lfs merge=lfs -text
|
37 |
+
assets/image_demo.gif filter=lfs diff=lfs merge=lfs -text
|
38 |
+
assets/image_doc_rag.gif filter=lfs diff=lfs merge=lfs -text
|
39 |
+
assets/image_doc.gif filter=lfs diff=lfs merge=lfs -text
|
40 |
+
assets/rag_gif.gif filter=lfs diff=lfs merge=lfs -text
|
41 |
+
assets/text_completion_gif.gif filter=lfs diff=lfs merge=lfs -text
|
LICENSE
ADDED
@@ -0,0 +1,201 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
Apache License
|
2 |
+
Version 2.0, January 2004
|
3 |
+
http://www.apache.org/licenses/
|
4 |
+
|
5 |
+
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
|
6 |
+
|
7 |
+
1. Definitions.
|
8 |
+
|
9 |
+
"License" shall mean the terms and conditions for use, reproduction,
|
10 |
+
and distribution as defined by Sections 1 through 9 of this document.
|
11 |
+
|
12 |
+
"Licensor" shall mean the copyright owner or entity authorized by
|
13 |
+
the copyright owner that is granting the License.
|
14 |
+
|
15 |
+
"Legal Entity" shall mean the union of the acting entity and all
|
16 |
+
other entities that control, are controlled by, or are under common
|
17 |
+
control with that entity. For the purposes of this definition,
|
18 |
+
"control" means (i) the power, direct or indirect, to cause the
|
19 |
+
direction or management of such entity, whether by contract or
|
20 |
+
otherwise, or (ii) ownership of fifty percent (50%) or more of the
|
21 |
+
outstanding shares, or (iii) beneficial ownership of such entity.
|
22 |
+
|
23 |
+
"You" (or "Your") shall mean an individual or Legal Entity
|
24 |
+
exercising permissions granted by this License.
|
25 |
+
|
26 |
+
"Source" form shall mean the preferred form for making modifications,
|
27 |
+
including but not limited to software source code, documentation
|
28 |
+
source, and configuration files.
|
29 |
+
|
30 |
+
"Object" form shall mean any form resulting from mechanical
|
31 |
+
transformation or translation of a Source form, including but
|
32 |
+
not limited to compiled object code, generated documentation,
|
33 |
+
and conversions to other media types.
|
34 |
+
|
35 |
+
"Work" shall mean the work of authorship, whether in Source or
|
36 |
+
Object form, made available under the License, as indicated by a
|
37 |
+
copyright notice that is included in or attached to the work
|
38 |
+
(an example is provided in the Appendix below).
|
39 |
+
|
40 |
+
"Derivative Works" shall mean any work, whether in Source or Object
|
41 |
+
form, that is based on (or derived from) the Work and for which the
|
42 |
+
editorial revisions, annotations, elaborations, or other modifications
|
43 |
+
represent, as a whole, an original work of authorship. For the purposes
|
44 |
+
of this License, Derivative Works shall not include works that remain
|
45 |
+
separable from, or merely link (or bind by name) to the interfaces of,
|
46 |
+
the Work and Derivative Works thereof.
|
47 |
+
|
48 |
+
"Contribution" shall mean any work of authorship, including
|
49 |
+
the original version of the Work and any modifications or additions
|
50 |
+
to that Work or Derivative Works thereof, that is intentionally
|
51 |
+
submitted to Licensor for inclusion in the Work by the copyright owner
|
52 |
+
or by an individual or Legal Entity authorized to submit on behalf of
|
53 |
+
the copyright owner. For the purposes of this definition, "submitted"
|
54 |
+
means any form of electronic, verbal, or written communication sent
|
55 |
+
to the Licensor or its representatives, including but not limited to
|
56 |
+
communication on electronic mailing lists, source code control systems,
|
57 |
+
and issue tracking systems that are managed by, or on behalf of, the
|
58 |
+
Licensor for the purpose of discussing and improving the Work, but
|
59 |
+
excluding communication that is conspicuously marked or otherwise
|
60 |
+
designated in writing by the copyright owner as "Not a Contribution."
|
61 |
+
|
62 |
+
"Contributor" shall mean Licensor and any individual or Legal Entity
|
63 |
+
on behalf of whom a Contribution has been received by Licensor and
|
64 |
+
subsequently incorporated within the Work.
|
65 |
+
|
66 |
+
2. Grant of Copyright License. Subject to the terms and conditions of
|
67 |
+
this License, each Contributor hereby grants to You a perpetual,
|
68 |
+
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
69 |
+
copyright license to reproduce, prepare Derivative Works of,
|
70 |
+
publicly display, publicly perform, sublicense, and distribute the
|
71 |
+
Work and such Derivative Works in Source or Object form.
|
72 |
+
|
73 |
+
3. Grant of Patent License. Subject to the terms and conditions of
|
74 |
+
this License, each Contributor hereby grants to You a perpetual,
|
75 |
+
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
76 |
+
(except as stated in this section) patent license to make, have made,
|
77 |
+
use, offer to sell, sell, import, and otherwise transfer the Work,
|
78 |
+
where such license applies only to those patent claims licensable
|
79 |
+
by such Contributor that are necessarily infringed by their
|
80 |
+
Contribution(s) alone or by combination of their Contribution(s)
|
81 |
+
with the Work to which such Contribution(s) was submitted. If You
|
82 |
+
institute patent litigation against any entity (including a
|
83 |
+
cross-claim or counterclaim in a lawsuit) alleging that the Work
|
84 |
+
or a Contribution incorporated within the Work constitutes direct
|
85 |
+
or contributory patent infringement, then any patent licenses
|
86 |
+
granted to You under this License for that Work shall terminate
|
87 |
+
as of the date such litigation is filed.
|
88 |
+
|
89 |
+
4. Redistribution. You may reproduce and distribute copies of the
|
90 |
+
Work or Derivative Works thereof in any medium, with or without
|
91 |
+
modifications, and in Source or Object form, provided that You
|
92 |
+
meet the following conditions:
|
93 |
+
|
94 |
+
(a) You must give any other recipients of the Work or
|
95 |
+
Derivative Works a copy of this License; and
|
96 |
+
|
97 |
+
(b) You must cause any modified files to carry prominent notices
|
98 |
+
stating that You changed the files; and
|
99 |
+
|
100 |
+
(c) You must retain, in the Source form of any Derivative Works
|
101 |
+
that You distribute, all copyright, patent, trademark, and
|
102 |
+
attribution notices from the Source form of the Work,
|
103 |
+
excluding those notices that do not pertain to any part of
|
104 |
+
the Derivative Works; and
|
105 |
+
|
106 |
+
(d) If the Work includes a "NOTICE" text file as part of its
|
107 |
+
distribution, then any Derivative Works that You distribute must
|
108 |
+
include a readable copy of the attribution notices contained
|
109 |
+
within such NOTICE file, excluding those notices that do not
|
110 |
+
pertain to any part of the Derivative Works, in at least one
|
111 |
+
of the following places: within a NOTICE text file distributed
|
112 |
+
as part of the Derivative Works; within the Source form or
|
113 |
+
documentation, if provided along with the Derivative Works; or,
|
114 |
+
within a display generated by the Derivative Works, if and
|
115 |
+
wherever such third-party notices normally appear. The contents
|
116 |
+
of the NOTICE file are for informational purposes only and
|
117 |
+
do not modify the License. You may add Your own attribution
|
118 |
+
notices within Derivative Works that You distribute, alongside
|
119 |
+
or as an addendum to the NOTICE text from the Work, provided
|
120 |
+
that such additional attribution notices cannot be construed
|
121 |
+
as modifying the License.
|
122 |
+
|
123 |
+
You may add Your own copyright statement to Your modifications and
|
124 |
+
may provide additional or different license terms and conditions
|
125 |
+
for use, reproduction, or distribution of Your modifications, or
|
126 |
+
for any such Derivative Works as a whole, provided Your use,
|
127 |
+
reproduction, and distribution of the Work otherwise complies with
|
128 |
+
the conditions stated in this License.
|
129 |
+
|
130 |
+
5. Submission of Contributions. Unless You explicitly state otherwise,
|
131 |
+
any Contribution intentionally submitted for inclusion in the Work
|
132 |
+
by You to the Licensor shall be under the terms and conditions of
|
133 |
+
this License, without any additional terms or conditions.
|
134 |
+
Notwithstanding the above, nothing herein shall supersede or modify
|
135 |
+
the terms of any separate license agreement you may have executed
|
136 |
+
with Licensor regarding such Contributions.
|
137 |
+
|
138 |
+
6. Trademarks. This License does not grant permission to use the trade
|
139 |
+
names, trademarks, service marks, or product names of the Licensor,
|
140 |
+
except as required for reasonable and customary use in describing the
|
141 |
+
origin of the Work and reproducing the content of the NOTICE file.
|
142 |
+
|
143 |
+
7. Disclaimer of Warranty. Unless required by applicable law or
|
144 |
+
agreed to in writing, Licensor provides the Work (and each
|
145 |
+
Contributor provides its Contributions) on an "AS IS" BASIS,
|
146 |
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
|
147 |
+
implied, including, without limitation, any warranties or conditions
|
148 |
+
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
|
149 |
+
PARTICULAR PURPOSE. You are solely responsible for determining the
|
150 |
+
appropriateness of using or redistributing the Work and assume any
|
151 |
+
risks associated with Your exercise of permissions under this License.
|
152 |
+
|
153 |
+
8. Limitation of Liability. In no event and under no legal theory,
|
154 |
+
whether in tort (including negligence), contract, or otherwise,
|
155 |
+
unless required by applicable law (such as deliberate and grossly
|
156 |
+
negligent acts) or agreed to in writing, shall any Contributor be
|
157 |
+
liable to You for damages, including any direct, indirect, special,
|
158 |
+
incidental, or consequential damages of any character arising as a
|
159 |
+
result of this License or out of the use or inability to use the
|
160 |
+
Work (including but not limited to damages for loss of goodwill,
|
161 |
+
work stoppage, computer failure or malfunction, or any and all
|
162 |
+
other commercial damages or losses), even if such Contributor
|
163 |
+
has been advised of the possibility of such damages.
|
164 |
+
|
165 |
+
9. Accepting Warranty or Additional Liability. While redistributing
|
166 |
+
the Work or Derivative Works thereof, You may choose to offer,
|
167 |
+
and charge a fee for, acceptance of support, warranty, indemnity,
|
168 |
+
or other liability obligations and/or rights consistent with this
|
169 |
+
License. However, in accepting such obligations, You may act only
|
170 |
+
on Your own behalf and on Your sole responsibility, not on behalf
|
171 |
+
of any other Contributor, and only if You agree to indemnify,
|
172 |
+
defend, and hold each Contributor harmless for any liability
|
173 |
+
incurred by, or claims asserted against, such Contributor by reason
|
174 |
+
of your accepting any such warranty or additional liability.
|
175 |
+
|
176 |
+
END OF TERMS AND CONDITIONS
|
177 |
+
|
178 |
+
APPENDIX: How to apply the Apache License to your work.
|
179 |
+
|
180 |
+
To apply the Apache License to your work, attach the following
|
181 |
+
boilerplate notice, with the fields enclosed by brackets "[]"
|
182 |
+
replaced with your own identifying information. (Don't include
|
183 |
+
the brackets!) The text should be enclosed in the appropriate
|
184 |
+
comment syntax for the file format. We also recommend that a
|
185 |
+
file or class name and description of purpose be included on the
|
186 |
+
same "printed page" as the copyright notice for easier
|
187 |
+
identification within third-party archives.
|
188 |
+
|
189 |
+
Copyright [yyyy] [name of copyright owner]
|
190 |
+
|
191 |
+
Licensed under the Apache License, Version 2.0 (the "License");
|
192 |
+
you may not use this file except in compliance with the License.
|
193 |
+
You may obtain a copy of the License at
|
194 |
+
|
195 |
+
http://www.apache.org/licenses/LICENSE-2.0
|
196 |
+
|
197 |
+
Unless required by applicable law or agreed to in writing, software
|
198 |
+
distributed under the License is distributed on an "AS IS" BASIS,
|
199 |
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
200 |
+
See the License for the specific language governing permissions and
|
201 |
+
limitations under the License.
|
app.py
ADDED
@@ -0,0 +1,115 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright: DAMO Academy, Alibaba Group
|
2 |
+
# By Xuan Phi Nguyen at DAMO Academy, Alibaba Group
|
3 |
+
|
4 |
+
# Description:
|
5 |
+
"""
|
6 |
+
Demo script to launch Language chat model
|
7 |
+
"""
|
8 |
+
|
9 |
+
|
10 |
+
import os
|
11 |
+
from gradio.themes import ThemeClass as Theme
|
12 |
+
import numpy as np
|
13 |
+
import argparse
|
14 |
+
# import torch
|
15 |
+
import gradio as gr
|
16 |
+
from typing import Any, Iterator
|
17 |
+
from typing import Iterator, List, Optional, Tuple
|
18 |
+
import filelock
|
19 |
+
import glob
|
20 |
+
import json
|
21 |
+
import time
|
22 |
+
from gradio.routes import Request
|
23 |
+
from gradio.utils import SyncToAsyncIterator, async_iteration
|
24 |
+
from gradio.helpers import special_args
|
25 |
+
import anyio
|
26 |
+
from typing import AsyncGenerator, Callable, Literal, Union, cast
|
27 |
+
|
28 |
+
from gradio_client.documentation import document, set_documentation_group
|
29 |
+
|
30 |
+
from typing import List, Optional, Union, Dict, Tuple
|
31 |
+
from tqdm.auto import tqdm
|
32 |
+
from huggingface_hub import snapshot_download
|
33 |
+
from langchain_community.embeddings import HuggingFaceEmbeddings, HuggingFaceBgeEmbeddings
|
34 |
+
from gradio.components import Button, Component
|
35 |
+
from gradio.events import Dependency, EventListenerMethod
|
36 |
+
|
37 |
+
from multipurpose_chatbot.demos.base_demo import CustomTabbedInterface
|
38 |
+
|
39 |
+
from multipurpose_chatbot.configs import (
|
40 |
+
MODEL_TITLE,
|
41 |
+
MODEL_DESC,
|
42 |
+
MODEL_INFO,
|
43 |
+
CITE_MARKDOWN,
|
44 |
+
ALLOWED_PATHS,
|
45 |
+
PROXY,
|
46 |
+
PORT,
|
47 |
+
MODEL_PATH,
|
48 |
+
MODEL_NAME,
|
49 |
+
BACKEND,
|
50 |
+
DEMOS,
|
51 |
+
)
|
52 |
+
|
53 |
+
|
54 |
+
demo = None
|
55 |
+
|
56 |
+
|
57 |
+
|
58 |
+
|
59 |
+
|
60 |
+
def launch_demo():
|
61 |
+
global demo, MODEL_ENGINE
|
62 |
+
model_desc = MODEL_DESC
|
63 |
+
model_path = MODEL_PATH
|
64 |
+
|
65 |
+
print(f'Begin importing models')
|
66 |
+
from multipurpose_chatbot.demos import get_demo_class
|
67 |
+
|
68 |
+
# demos = {
|
69 |
+
# k: get_demo_class(k)().create_demo()
|
70 |
+
# for k in demo_and_tab_names.keys()
|
71 |
+
# }
|
72 |
+
print(f'{DEMOS=}')
|
73 |
+
demo_class_objects = {
|
74 |
+
k: get_demo_class(k)()
|
75 |
+
for k in DEMOS
|
76 |
+
}
|
77 |
+
demos = {
|
78 |
+
k: get_demo_class(k)().create_demo()
|
79 |
+
for k in DEMOS
|
80 |
+
}
|
81 |
+
demos_names = [x.tab_name for x in demo_class_objects.values()]
|
82 |
+
|
83 |
+
descriptions = model_desc
|
84 |
+
if MODEL_INFO is not None and MODEL_INFO != "":
|
85 |
+
descriptions += (
|
86 |
+
f"<br>" +
|
87 |
+
MODEL_INFO.format(model_path=model_path)
|
88 |
+
)
|
89 |
+
|
90 |
+
demo = CustomTabbedInterface(
|
91 |
+
interface_list=list(demos.values()),
|
92 |
+
tab_names=demos_names,
|
93 |
+
title=f"{MODEL_TITLE}",
|
94 |
+
description=descriptions,
|
95 |
+
)
|
96 |
+
|
97 |
+
demo.title = MODEL_NAME
|
98 |
+
|
99 |
+
with demo:
|
100 |
+
gr.Markdown(CITE_MARKDOWN)
|
101 |
+
|
102 |
+
demo.queue(api_open=False)
|
103 |
+
return demo
|
104 |
+
|
105 |
+
|
106 |
+
|
107 |
+
if __name__ == "__main__":
|
108 |
+
demo = launch_demo()
|
109 |
+
if PROXY is not None and PROXY != "":
|
110 |
+
print(f'{PROXY=} {PORT=}')
|
111 |
+
demo.launch(server_port=PORT, root_path=PROXY, show_api=False, allowed_paths=ALLOWED_PATHS)
|
112 |
+
else:
|
113 |
+
demo.launch(server_port=PORT, show_api=False, allowed_paths=ALLOWED_PATHS)
|
114 |
+
|
115 |
+
|
assets/attention_all_you_need.pdf
ADDED
Binary file (858 kB). View file
|
|
assets/attention_short.pdf
ADDED
Binary file (236 kB). View file
|
|
assets/doc_gif.gif
ADDED
Git LFS Details
|
assets/dog_monalisa.jpeg
ADDED
assets/image_demo.gif
ADDED
Git LFS Details
|
assets/image_doc.gif
ADDED
Git LFS Details
|
assets/image_doc_rag.gif
ADDED
Git LFS Details
|
assets/rag_gif.gif
ADDED
Git LFS Details
|
assets/text_completion_gif.gif
ADDED
Git LFS Details
|
assets/upload_chat.json
ADDED
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
[
|
2 |
+
{
|
3 |
+
"id": "1",
|
4 |
+
"prompt": "Tell me something about AI?"
|
5 |
+
},
|
6 |
+
{
|
7 |
+
"id": "2",
|
8 |
+
"prompt": "Who are you?"
|
9 |
+
}
|
10 |
+
]
|
assets/upload_few_shot.json
ADDED
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
[
|
2 |
+
{
|
3 |
+
"id": "0",
|
4 |
+
"prompt": "Translate Indonesian to English.\nIndonesian: \"Mereka melakukan hal ini dengan cara memancarkan sebuah partikel kecil cahaya kecil yang biasa disebut \"foton\".\"\nEnglish: They do this by emitting a tiny particle of light called a \"photon\".\n\nTranslate Indonesian to English.\nIndonesian: Kami melewati waktu seperti rangkaian peristiwa yang berlalu dari masa depan hingga masa kini lalu ke masa lalu.\nEnglish: We experience time as a series of events passing from the future through the present to the past.\n\nTranslate Indonesian to English.\nIndonesian: Canyoning (atau: canyoneering) adalah segala aktivitas yang terjadi di dasar ngarai, yang kering atau penuh air.\nEnglish: Canyoning (or: canyoneering) is about going in a bottom of a canyon, which is either dry or full of water.\n\nTranslate Indonesian to English.\nIndonesian: Mohon diingat bahwa intinya Anda sedang berkunjung ke situs kuburan massal, serta situs yang maknanya tak terhitung bagi sejumlah populasi dunia yang signifikan.\nEnglish:"
|
5 |
+
},
|
6 |
+
{
|
7 |
+
"id": "1",
|
8 |
+
"prompt": "Translate Indonesian to English.\nIndonesian: \"Mereka melakukan hal ini dengan cara memancarkan sebuah partikel kecil cahaya kecil yang biasa disebut \"foton\".\"\nEnglish: They do this by emitting a tiny particle of light called a \"photon\".\n\nTranslate Indonesian to English.\nIndonesian: Kami melewati waktu seperti rangkaian peristiwa yang berlalu dari masa depan hingga masa kini lalu ke masa lalu.\nEnglish: We experience time as a series of events passing from the future through the present to the past.\n\nTranslate Indonesian to English.\nIndonesian: Canyoning (atau: canyoneering) adalah segala aktivitas yang terjadi di dasar ngarai, yang kering atau penuh air.\nEnglish: Canyoning (or: canyoneering) is about going in a bottom of a canyon, which is either dry or full of water.\n\nTranslate Indonesian to English.\nIndonesian: Serangga adalah hewan pertama yang menjelajah angkasa. Kemampuan terbangnya membantu mereka menghindari musuh dengan lebih mudah dan mencari makanan dan pasangan dengan lebih efisien.\nEnglish:"
|
9 |
+
}
|
10 |
+
]
|
llama_cpp_requirements.txt
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
llama-cpp-python
|
mlx_requirements.txt
ADDED
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
1 |
+
mlx
|
2 |
+
mlx-lm
|
multipurpose_chatbot/.DS_Store
ADDED
Binary file (6.15 kB). View file
|
|
multipurpose_chatbot/__init__.py
ADDED
File without changes
|
multipurpose_chatbot/configs.py
ADDED
@@ -0,0 +1,110 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
|
2 |
+
import os
|
3 |
+
|
4 |
+
# ! UI Markdown information
|
5 |
+
|
6 |
+
MODEL_TITLE = "<h1>Multi-Purpose Chatbot</h1>"
|
7 |
+
|
8 |
+
MODEL_DESC = f"""
|
9 |
+
<div style='display:flex; gap: 0.25rem; '>
|
10 |
+
<a href='https://github.com/DAMO-NLP-SG/Multipurpose-Chatbot'><img src='https://img.shields.io/badge/Github-Code-success'></a>
|
11 |
+
</div>
|
12 |
+
<span style="font-size: larger">
|
13 |
+
A multi-purpose helpful assistant with multiple functionalities (Chat, text-completion, RAG chat, batch inference).
|
14 |
+
</span>
|
15 |
+
""".strip()
|
16 |
+
|
17 |
+
|
18 |
+
|
19 |
+
MODEL_INFO = """
|
20 |
+
<h4>Model Name: {model_path}</h4>
|
21 |
+
"""
|
22 |
+
|
23 |
+
CITE_MARKDOWN = """
|
24 |
+
## Citation
|
25 |
+
If you find our project useful, hope you can star our repo and cite our repo as follows:
|
26 |
+
```
|
27 |
+
@article{multipurpose_chatbot_2024,
|
28 |
+
author = {Xuan-Phi Nguyen, },
|
29 |
+
title = {Multipurpose Chatbot},
|
30 |
+
year = 2024,
|
31 |
+
}
|
32 |
+
```
|
33 |
+
"""
|
34 |
+
|
35 |
+
USE_PANEL = bool(int(os.environ.get("USE_PANEL", "1")))
|
36 |
+
CHATBOT_HEIGHT = int(os.environ.get("CHATBOT_HEIGHT", "500"))
|
37 |
+
|
38 |
+
ALLOWED_PATHS = []
|
39 |
+
|
40 |
+
|
41 |
+
DEMOS = os.environ.get("DEMOS", "")
|
42 |
+
|
43 |
+
DEMOS = DEMOS.split(",") if DEMOS.strip() != "" else [
|
44 |
+
"DocChatInterfaceDemo",
|
45 |
+
"ChatInterfaceDemo",
|
46 |
+
"TextCompletionDemo",
|
47 |
+
# "RagChatInterfaceDemo",
|
48 |
+
# "VisionChatInterfaceDemo",
|
49 |
+
# "VisionDocChatInterfaceDemo",
|
50 |
+
]
|
51 |
+
|
52 |
+
# DEMOS=DocChatInterfaceDemo,ChatInterfaceDemo,RagChatInterfaceDemo,TextCompletionDemo
|
53 |
+
|
54 |
+
|
55 |
+
|
56 |
+
# ! server info
|
57 |
+
|
58 |
+
PORT = int(os.environ.get("PORT", "7860"))
|
59 |
+
PROXY = os.environ.get("PROXY", "").strip()
|
60 |
+
|
61 |
+
# ! backend info
|
62 |
+
|
63 |
+
BACKEND = os.environ.get("BACKEND", "debug")
|
64 |
+
|
65 |
+
# ! model information
|
66 |
+
# for RAG
|
67 |
+
RAG_EMBED_MODEL_NAME = os.environ.get("RAG_EMBED_MODEL_NAME", "sentence-transformers/all-MiniLM-L6-v2")
|
68 |
+
CHUNK_SIZE = int(os.environ.get("CHUNK_SIZE", "1024"))
|
69 |
+
CHUNK_OVERLAP = int(os.environ.get("CHUNK_SIZE", "50"))
|
70 |
+
|
71 |
+
|
72 |
+
SYSTEM_PROMPT = os.environ.get("SYSTEM_PROMPT", """You are a helpful, respectful, honest and safe AI assistant.""")
|
73 |
+
|
74 |
+
MAX_TOKENS = int(os.environ.get("MAX_TOKENS", "2048"))
|
75 |
+
TEMPERATURE = float(os.environ.get("TEMPERATURE", "0.7"))
|
76 |
+
# ! these values currently not used
|
77 |
+
FREQUENCE_PENALTY = float(os.environ.get("FREQUENCE_PENALTY", "0.0"))
|
78 |
+
PRESENCE_PENALTY = float(os.environ.get("PRESENCE_PENALTY", "0.0"))
|
79 |
+
|
80 |
+
|
81 |
+
# Transformers or vllm
|
82 |
+
MODEL_PATH = os.environ.get("MODEL_PATH", "teknium/OpenHermes-2.5-Mistral-7B")
|
83 |
+
MODEL_NAME = os.environ.get("MODEL_NAME", "Cool-Chatbot")
|
84 |
+
DTYPE = os.environ.get("DTYPE", "bfloat16")
|
85 |
+
DEVICE = os.environ.get("DEVICE", "cuda")
|
86 |
+
|
87 |
+
# VLLM
|
88 |
+
GPU_MEMORY_UTILIZATION = float(os.environ.get("GPU_MEMORY_UTILIZATION", "0.9"))
|
89 |
+
TENSOR_PARALLEL = int(os.environ.get("TENSOR_PARALLEL", "1"))
|
90 |
+
QUANTIZATION = str(os.environ.get("QUANTIZATION", ""))
|
91 |
+
STREAM_YIELD_MULTIPLE = int(os.environ.get("STREAM_YIELD_MULTIPLE", "1"))
|
92 |
+
# how many iterations to perform safety check on response
|
93 |
+
STREAM_CHECK_MULTIPLE = int(os.environ.get("STREAM_CHECK_MULTIPLE", "0"))
|
94 |
+
|
95 |
+
# llama.cpp
|
96 |
+
DEFAULT_CHAT_TEMPLATE = os.environ.get("DEFAULT_CHAT_TEMPLATE", "chatml")
|
97 |
+
N_CTX = int(os.environ.get("N_CTX", "4096"))
|
98 |
+
N_GPU_LAYERS = int(os.environ.get("N_GPU_LAYERS", "-1"))
|
99 |
+
|
100 |
+
# llava.llama.cpp
|
101 |
+
# ! pending development
|
102 |
+
|
103 |
+
# Multimodal
|
104 |
+
# IMAGE_TOKEN = os.environ.get("IMAGE_TOKEN", "[IMAGE]<|image|>[/IMAGE]")
|
105 |
+
IMAGE_TOKEN = os.environ.get("IMAGE_TOKEN", "<image>")
|
106 |
+
IMAGE_TOKEN_INTERACTIVE = bool(int(os.environ.get("IMAGE_TOKEN_INTERACTIVE", "0")))
|
107 |
+
# ! IMAGE_TOKEN_LENGTH expected embedding lengths of an image to calculate the actual tokens
|
108 |
+
IMAGE_TOKEN_LENGTH = int(os.environ.get("IMAGE_TOKEN_LENGTH", "576"))
|
109 |
+
# ! Llava1.6 to calculate the maximum number of patches in an image (max=5 for Llava1.6)
|
110 |
+
MAX_PACHES = int(os.environ.get("MAX_PACHES", "1"))
|
multipurpose_chatbot/demos/.DS_Store
ADDED
Binary file (6.15 kB). View file
|
|
multipurpose_chatbot/demos/__init__.py
ADDED
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
|
2 |
+
from .base_demo import *
|
3 |
+
|
4 |
+
from .chat_interface import ChatInterfaceDemo
|
5 |
+
from .rag_chat_interface import RagChatInterfaceDemo
|
6 |
+
from .multimodal_chat_interface import *
|
7 |
+
from .text_completion import *
|
8 |
+
from .batch_inference import *
|
multipurpose_chatbot/demos/base_demo.py
ADDED
@@ -0,0 +1,105 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
from gradio.themes import ThemeClass as Theme
|
3 |
+
import numpy as np
|
4 |
+
import argparse
|
5 |
+
import gradio as gr
|
6 |
+
from typing import Any, Iterator
|
7 |
+
from typing import Iterator, List, Optional, Tuple
|
8 |
+
import filelock
|
9 |
+
import glob
|
10 |
+
import json
|
11 |
+
import time
|
12 |
+
from gradio.routes import Request
|
13 |
+
from gradio.utils import SyncToAsyncIterator, async_iteration
|
14 |
+
from gradio.helpers import special_args
|
15 |
+
import anyio
|
16 |
+
from typing import AsyncGenerator, Callable, Literal, Union, cast
|
17 |
+
|
18 |
+
from gradio_client.documentation import document, set_documentation_group
|
19 |
+
from gradio.components import Button, Component
|
20 |
+
from gradio.events import Dependency, EventListenerMethod
|
21 |
+
from typing import List, Optional, Union, Dict, Tuple
|
22 |
+
from tqdm.auto import tqdm
|
23 |
+
from huggingface_hub import snapshot_download
|
24 |
+
|
25 |
+
|
26 |
+
def create_class_func_registry():
|
27 |
+
registry = {}
|
28 |
+
def register_registry(cls, exist_ok=False):
|
29 |
+
assert exist_ok or cls.__name__ not in registry, f'{cls} already in registry: {registry}'
|
30 |
+
registry[cls.__name__] = cls
|
31 |
+
return cls
|
32 |
+
|
33 |
+
def get_registry(name):
|
34 |
+
assert name in registry, f'{name} not in registry: {registry}'
|
35 |
+
return registry[name]
|
36 |
+
|
37 |
+
return registry, register_registry, get_registry
|
38 |
+
|
39 |
+
DEMOS, register_demo, get_demo_class = create_class_func_registry()
|
40 |
+
|
41 |
+
|
42 |
+
class BaseDemo(object):
|
43 |
+
"""
|
44 |
+
All demo should be created from BaseDemo and registered with @register_demo
|
45 |
+
"""
|
46 |
+
def __init__(self) -> None:
|
47 |
+
pass
|
48 |
+
|
49 |
+
@property
|
50 |
+
def tab_name(self):
|
51 |
+
return "Demo"
|
52 |
+
|
53 |
+
def create_demo(
|
54 |
+
self,
|
55 |
+
title: Optional[str] = None,
|
56 |
+
description: Optional[str] = None,
|
57 |
+
**kwargs,
|
58 |
+
) -> gr.Blocks:
|
59 |
+
pass
|
60 |
+
|
61 |
+
|
62 |
+
@document()
|
63 |
+
class CustomTabbedInterface(gr.Blocks):
|
64 |
+
def __init__(
|
65 |
+
self,
|
66 |
+
interface_list: list[gr.Interface],
|
67 |
+
tab_names: Optional[list[str]] = None,
|
68 |
+
title: Optional[str] = None,
|
69 |
+
description: Optional[str] = None,
|
70 |
+
theme: Optional[gr.Theme] = None,
|
71 |
+
analytics_enabled: Optional[bool] = None,
|
72 |
+
css: Optional[str] = None,
|
73 |
+
):
|
74 |
+
"""
|
75 |
+
Parameters:
|
76 |
+
interface_list: a list of interfaces to be rendered in tabs.
|
77 |
+
tab_names: a list of tab names. If None, the tab names will be "Tab 1", "Tab 2", etc.
|
78 |
+
title: a title for the interface; if provided, appears above the input and output components in large font. Also used as the tab title when opened in a browser window.
|
79 |
+
analytics_enabled: whether to allow basic telemetry. If None, will use GRADIO_ANALYTICS_ENABLED environment variable or default to True.
|
80 |
+
css: custom css or path to custom css file to apply to entire Blocks
|
81 |
+
Returns:
|
82 |
+
a Gradio Tabbed Interface for the given interfaces
|
83 |
+
"""
|
84 |
+
super().__init__(
|
85 |
+
title=title or "Gradio",
|
86 |
+
theme=theme,
|
87 |
+
analytics_enabled=analytics_enabled,
|
88 |
+
mode="tabbed_interface",
|
89 |
+
css=css,
|
90 |
+
)
|
91 |
+
self.description = description
|
92 |
+
if tab_names is None:
|
93 |
+
tab_names = [f"Tab {i}" for i in range(len(interface_list))]
|
94 |
+
with self:
|
95 |
+
if title:
|
96 |
+
gr.Markdown(
|
97 |
+
f"<h1 style='text-align: center; margin-bottom: 1rem'>{title}</h1>"
|
98 |
+
)
|
99 |
+
if description:
|
100 |
+
gr.Markdown(description)
|
101 |
+
with gr.Tabs():
|
102 |
+
for interface, tab_name in zip(interface_list, tab_names):
|
103 |
+
with gr.Tab(label=tab_name):
|
104 |
+
interface.render()
|
105 |
+
|
multipurpose_chatbot/demos/batch_inference.py
ADDED
@@ -0,0 +1,246 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
from gradio.themes import ThemeClass as Theme
|
3 |
+
import numpy as np
|
4 |
+
import argparse
|
5 |
+
import gradio as gr
|
6 |
+
from typing import Any, Iterator
|
7 |
+
from typing import Iterator, List, Optional, Tuple
|
8 |
+
import filelock
|
9 |
+
import glob
|
10 |
+
import json
|
11 |
+
import time
|
12 |
+
from gradio.routes import Request
|
13 |
+
from gradio.utils import SyncToAsyncIterator, async_iteration
|
14 |
+
from gradio.helpers import special_args
|
15 |
+
import anyio
|
16 |
+
from typing import AsyncGenerator, Callable, Literal, Union, cast, Generator
|
17 |
+
|
18 |
+
from gradio_client.documentation import document, set_documentation_group
|
19 |
+
from gradio.components import Button, Component
|
20 |
+
from gradio.events import Dependency, EventListenerMethod
|
21 |
+
from typing import List, Optional, Union, Dict, Tuple
|
22 |
+
from tqdm.auto import tqdm
|
23 |
+
from huggingface_hub import snapshot_download
|
24 |
+
|
25 |
+
|
26 |
+
import inspect
|
27 |
+
from typing import AsyncGenerator, Callable, Literal, Union, cast
|
28 |
+
|
29 |
+
import anyio
|
30 |
+
from gradio_client import utils as client_utils
|
31 |
+
from gradio_client.documentation import document
|
32 |
+
|
33 |
+
from gradio.blocks import Blocks
|
34 |
+
from gradio.components import (
|
35 |
+
Button,
|
36 |
+
Chatbot,
|
37 |
+
Component,
|
38 |
+
Markdown,
|
39 |
+
State,
|
40 |
+
Textbox,
|
41 |
+
get_component_instance,
|
42 |
+
)
|
43 |
+
from gradio.events import Dependency, on
|
44 |
+
from gradio.helpers import create_examples as Examples # noqa: N812
|
45 |
+
from gradio.helpers import special_args
|
46 |
+
from gradio.layouts import Accordion, Group, Row
|
47 |
+
from gradio.routes import Request
|
48 |
+
from gradio.themes import ThemeClass as Theme
|
49 |
+
from gradio.utils import SyncToAsyncIterator, async_iteration
|
50 |
+
|
51 |
+
|
52 |
+
from .base_demo import register_demo, get_demo_class, BaseDemo
|
53 |
+
from ..configs import (
|
54 |
+
SYSTEM_PROMPT,
|
55 |
+
MODEL_NAME,
|
56 |
+
MAX_TOKENS,
|
57 |
+
TEMPERATURE,
|
58 |
+
USE_PANEL,
|
59 |
+
CHATBOT_HEIGHT,
|
60 |
+
)
|
61 |
+
|
62 |
+
from ..globals import MODEL_ENGINE
|
63 |
+
|
64 |
+
from .chat_interface import gradio_history_to_conversation_prompt
|
65 |
+
|
66 |
+
# Batch inference file upload
|
67 |
+
ENABLE_BATCH_INFER = bool(int(os.environ.get("ENABLE_BATCH_INFER", "1")))
|
68 |
+
BATCH_INFER_MAX_ITEMS = int(os.environ.get("BATCH_INFER_MAX_ITEMS", "100"))
|
69 |
+
BATCH_INFER_MAX_FILE_SIZE = int(os.environ.get("BATCH_INFER_MAX_FILE_SIZE", "500"))
|
70 |
+
BATCH_INFER_MAX_PROMPT_TOKENS = int(os.environ.get("BATCH_INFER_MAX_PROMPT_TOKENS", "4000"))
|
71 |
+
BATCH_INFER_SAVE_TMP_FILE = os.environ.get("BATCH_INFER_SAVE_TMP_FILE", "./tmp/pred.json")
|
72 |
+
|
73 |
+
|
74 |
+
|
75 |
+
FILE_UPLOAD_DESCRIPTION = f"""Upload JSON file as list of dict with < {BATCH_INFER_MAX_ITEMS} items, \
|
76 |
+
each item has `prompt` key. We put guardrails to enhance safety, so do not input any harmful content or personal information! Re-upload the file after every submit. See the examples below.
|
77 |
+
```
|
78 |
+
[ {{"id": 0, "prompt": "Hello world"}} , {{"id": 1, "prompt": "Hi there?"}}]
|
79 |
+
```
|
80 |
+
"""
|
81 |
+
|
82 |
+
def validate_file_item(filename, index, item: Dict[str, str]):
|
83 |
+
"""
|
84 |
+
check safety for items in files
|
85 |
+
"""
|
86 |
+
global MODEL_ENGINE
|
87 |
+
message = item['prompt'].strip()
|
88 |
+
|
89 |
+
if len(message) == 0:
|
90 |
+
raise gr.Error(f'Prompt {index} empty')
|
91 |
+
|
92 |
+
num_tokens = len(MODEL_ENGINE.tokenizer.encode(message))
|
93 |
+
if num_tokens >= MODEL_ENGINE.max_position_embeddings - 128:
|
94 |
+
raise gr.Error(f"Conversation or prompt is too long ({num_tokens} toks), please clear the chatbox or try shorter input.")
|
95 |
+
|
96 |
+
|
97 |
+
def read_validate_json_files(files: Union[str, List[str]]):
|
98 |
+
files = files if isinstance(files, list) else [files]
|
99 |
+
filenames = [f.name for f in files]
|
100 |
+
all_items = []
|
101 |
+
for fname in filenames:
|
102 |
+
# check each files
|
103 |
+
print(f'Reading {fname}')
|
104 |
+
with open(fname, 'r', encoding='utf-8') as f:
|
105 |
+
items = json.load(f)
|
106 |
+
assert isinstance(items, list), f'Data {fname} not list'
|
107 |
+
assert all(isinstance(x, dict) for x in items), f'item in input file not list'
|
108 |
+
assert all("prompt" in x for x in items), f'key prompt should be in dict item of input file'
|
109 |
+
|
110 |
+
for i, x in enumerate(items):
|
111 |
+
validate_file_item(fname, i, x)
|
112 |
+
|
113 |
+
all_items.extend(items)
|
114 |
+
|
115 |
+
if len(all_items) > BATCH_INFER_MAX_ITEMS:
|
116 |
+
raise gr.Error(f"Num samples {len(all_items)} > {BATCH_INFER_MAX_ITEMS} allowed.")
|
117 |
+
|
118 |
+
return all_items, filenames
|
119 |
+
|
120 |
+
|
121 |
+
def remove_gradio_cache(exclude_names=None):
|
122 |
+
"""remove gradio cache to avoid flooding"""
|
123 |
+
import shutil
|
124 |
+
for root, dirs, files in os.walk('/tmp/gradio/'):
|
125 |
+
for f in files:
|
126 |
+
# if not any(f in ef for ef in except_files):
|
127 |
+
if exclude_names is None or not any(ef in f for ef in exclude_names):
|
128 |
+
print(f'Remove: {f}')
|
129 |
+
os.unlink(os.path.join(root, f))
|
130 |
+
|
131 |
+
|
132 |
+
def free_form_prompt(prompt, history=None, system_prompt=None):
|
133 |
+
return prompt
|
134 |
+
|
135 |
+
|
136 |
+
|
137 |
+
|
138 |
+
def batch_inference_engine(
|
139 |
+
files: Union[str, List[str]],
|
140 |
+
prompt_mode: str,
|
141 |
+
temperature: float,
|
142 |
+
max_tokens: int,
|
143 |
+
stop_strings: str = "<s>,</s>,<|im_start|>",
|
144 |
+
system_prompt: Optional[str] = SYSTEM_PROMPT,
|
145 |
+
):
|
146 |
+
global MODEL_ENGINE
|
147 |
+
temperature = float(temperature)
|
148 |
+
max_tokens = int(max_tokens)
|
149 |
+
stop_strings = [x.strip() for x in stop_strings.strip().split(",")]
|
150 |
+
|
151 |
+
all_items, filenames = read_validate_json_files(files)
|
152 |
+
|
153 |
+
# remove all items in /tmp/gradio/
|
154 |
+
remove_gradio_cache(exclude_names=['upload_chat.json', 'upload_few_shot.json'])
|
155 |
+
|
156 |
+
if prompt_mode == 'chat':
|
157 |
+
prompt_format_fn = gradio_history_to_conversation_prompt
|
158 |
+
elif prompt_mode == 'few-shot':
|
159 |
+
from functools import partial
|
160 |
+
prompt_format_fn = free_form_prompt
|
161 |
+
else:
|
162 |
+
raise gr.Error(f'Wrong mode {prompt_mode}')
|
163 |
+
|
164 |
+
full_prompts = [
|
165 |
+
prompt_format_fn(
|
166 |
+
x['prompt'], [], system_prompt=system_prompt
|
167 |
+
)
|
168 |
+
for i, x in enumerate(all_items)
|
169 |
+
]
|
170 |
+
print(f'{full_prompts[0]}\n')
|
171 |
+
|
172 |
+
full_num_tokens = [
|
173 |
+
len(MODEL_ENGINE.tokenizer.encode(p))
|
174 |
+
for p in full_prompts
|
175 |
+
]
|
176 |
+
if any(x >= MODEL_ENGINE.max_position_embeddings - 128 for x in full_num_tokens):
|
177 |
+
raise gr.Error(f"Some prompt is too long!")
|
178 |
+
|
179 |
+
# ! batch inference
|
180 |
+
responses = MODEL_ENGINE.batch_generate(
|
181 |
+
full_prompts,
|
182 |
+
temperature=temperature, max_tokens=max_tokens,
|
183 |
+
stop_strings=stop_strings,
|
184 |
+
)
|
185 |
+
|
186 |
+
if len(responses) != len(all_items):
|
187 |
+
raise gr.Error(f'inconsistent lengths {len(responses)} != {len(all_items)}')
|
188 |
+
|
189 |
+
for res, item in zip(responses, all_items):
|
190 |
+
item['response'] = res
|
191 |
+
|
192 |
+
save_path = BATCH_INFER_SAVE_TMP_FILE
|
193 |
+
os.makedirs(os.path.dirname(save_path), exist_ok=True)
|
194 |
+
with open(save_path, 'w', encoding='utf-8') as f:
|
195 |
+
json.dump(all_items, f, indent=4, ensure_ascii=False)
|
196 |
+
|
197 |
+
print_items = all_items[:2]
|
198 |
+
print(json.dumps(print_items, indent=4, ensure_ascii=False))
|
199 |
+
return save_path, print_items
|
200 |
+
|
201 |
+
|
202 |
+
class BatchInferenceDemo(BaseDemo):
|
203 |
+
def tab_name(self):
|
204 |
+
return "Batch Inference"
|
205 |
+
|
206 |
+
|
207 |
+
def create_demo(
|
208 |
+
self,
|
209 |
+
title: str | None = None,
|
210 |
+
description: str | None = None,
|
211 |
+
**kwargs
|
212 |
+
) -> gr.Blocks:
|
213 |
+
system_prompt = kwargs.get("system_prompt", SYSTEM_PROMPT)
|
214 |
+
max_tokens = kwargs.get("max_tokens", MAX_TOKENS)
|
215 |
+
temperature = kwargs.get("temperature", TEMPERATURE)
|
216 |
+
model_name = kwargs.get("model_name", MODEL_NAME)
|
217 |
+
|
218 |
+
|
219 |
+
demo_file_upload = gr.Interface(
|
220 |
+
batch_inference_engine,
|
221 |
+
inputs=[
|
222 |
+
gr.File(file_count='single', file_types=['json']),
|
223 |
+
gr.Radio(["chat", "few-shot"], value='chat', label="Chat or Few-shot mode", info="Chat's output more user-friendly, Few-shot's output more consistent with few-shot patterns."),
|
224 |
+
gr.Number(value=temperature, label='Temperature', info="Higher -> more random"),
|
225 |
+
gr.Number(value=max_tokens, label='Max tokens', info='Increase if want more generation'),
|
226 |
+
# gr.Number(value=frequence_penalty, label='Frequency penalty', info='> 0 encourage new tokens over repeated tokens'),
|
227 |
+
# gr.Number(value=presence_penalty, label='Presence penalty', info='> 0 encourage new tokens, < 0 encourage existing tokens'),
|
228 |
+
gr.Textbox(value="<s>,</s>,<|im_start|>", label='Stop strings', info='Comma-separated string to stop generation only in FEW-SHOT mode', lines=1),
|
229 |
+
# gr.Number(value=0, label='current_time', visible=False),
|
230 |
+
gr.Textbox(value=system_prompt, label='System prompt', lines=4)
|
231 |
+
],
|
232 |
+
outputs=[
|
233 |
+
# "file",
|
234 |
+
gr.File(label="Generated file"),
|
235 |
+
# "json"
|
236 |
+
gr.JSON(label='Example outputs (display 2 samples)')
|
237 |
+
],
|
238 |
+
description=FILE_UPLOAD_DESCRIPTION,
|
239 |
+
allow_flagging=False,
|
240 |
+
examples=[
|
241 |
+
["upload_chat.json", "chat"],
|
242 |
+
["upload_few_shot.json", "few-shot"],
|
243 |
+
],
|
244 |
+
cache_examples=False,
|
245 |
+
)
|
246 |
+
return demo_file_upload
|
multipurpose_chatbot/demos/chat_interface.py
ADDED
@@ -0,0 +1,704 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
from gradio.themes import ThemeClass as Theme
|
3 |
+
import numpy as np
|
4 |
+
import argparse
|
5 |
+
import gradio as gr
|
6 |
+
from typing import Any, Iterator
|
7 |
+
from typing import Iterator, List, Optional, Tuple
|
8 |
+
import filelock
|
9 |
+
import glob
|
10 |
+
import json
|
11 |
+
import time
|
12 |
+
from gradio.routes import Request
|
13 |
+
from gradio.utils import SyncToAsyncIterator, async_iteration
|
14 |
+
from gradio.helpers import special_args
|
15 |
+
import anyio
|
16 |
+
from typing import AsyncGenerator, Callable, Literal, Union, cast, Generator
|
17 |
+
|
18 |
+
from gradio_client.documentation import document, set_documentation_group
|
19 |
+
from gradio.components import Button, Component
|
20 |
+
from gradio.events import Dependency, EventListenerMethod
|
21 |
+
from typing import List, Optional, Union, Dict, Tuple
|
22 |
+
from tqdm.auto import tqdm
|
23 |
+
from huggingface_hub import snapshot_download
|
24 |
+
|
25 |
+
|
26 |
+
import inspect
|
27 |
+
from typing import AsyncGenerator, Callable, Literal, Union, cast
|
28 |
+
|
29 |
+
import anyio
|
30 |
+
from gradio_client import utils as client_utils
|
31 |
+
from gradio_client.documentation import document
|
32 |
+
|
33 |
+
from gradio.blocks import Blocks
|
34 |
+
from gradio.components import (
|
35 |
+
Button,
|
36 |
+
Chatbot,
|
37 |
+
Component,
|
38 |
+
Markdown,
|
39 |
+
State,
|
40 |
+
Textbox,
|
41 |
+
get_component_instance,
|
42 |
+
)
|
43 |
+
from gradio.events import Dependency, on
|
44 |
+
from gradio.helpers import create_examples as Examples # noqa: N812
|
45 |
+
from gradio.helpers import special_args
|
46 |
+
from gradio.layouts import Accordion, Group, Row
|
47 |
+
from gradio.routes import Request
|
48 |
+
from gradio.themes import ThemeClass as Theme
|
49 |
+
from gradio.utils import SyncToAsyncIterator, async_iteration
|
50 |
+
|
51 |
+
|
52 |
+
from .base_demo import register_demo, get_demo_class, BaseDemo
|
53 |
+
from ..configs import (
|
54 |
+
SYSTEM_PROMPT,
|
55 |
+
MODEL_NAME,
|
56 |
+
MAX_TOKENS,
|
57 |
+
TEMPERATURE,
|
58 |
+
USE_PANEL,
|
59 |
+
CHATBOT_HEIGHT,
|
60 |
+
)
|
61 |
+
|
62 |
+
from ..globals import MODEL_ENGINE
|
63 |
+
|
64 |
+
CHAT_EXAMPLES = [
|
65 |
+
["Explain general relativity."],
|
66 |
+
]
|
67 |
+
DATETIME_FORMAT = "Current date time: {cur_datetime}."
|
68 |
+
|
69 |
+
|
70 |
+
def gradio_history_to_openai_conversations(message=None, history=None, system_prompt=None):
|
71 |
+
conversations = []
|
72 |
+
system_prompt = system_prompt or SYSTEM_PROMPT
|
73 |
+
if history is not None and len(history) > 0:
|
74 |
+
for i, (prompt, res) in enumerate(history):
|
75 |
+
if prompt is not None:
|
76 |
+
conversations.append({"role": "user", "content": prompt.strip()})
|
77 |
+
if res is not None:
|
78 |
+
conversations.append({"role": "assistant", "content": res.strip()})
|
79 |
+
if message is not None:
|
80 |
+
if len(message.strip()) == 0:
|
81 |
+
raise gr.Error("The message cannot be empty!")
|
82 |
+
conversations.append({"role": "user", "content": message.strip()})
|
83 |
+
if conversations[0]['role'] != 'system':
|
84 |
+
conversations = [{"role": "system", "content": system_prompt}] + conversations
|
85 |
+
return conversations
|
86 |
+
|
87 |
+
|
88 |
+
def gradio_history_to_conversation_prompt(message=None, history=None, system_prompt=None):
|
89 |
+
global MODEL_ENGINE
|
90 |
+
full_prompt = MODEL_ENGINE.apply_chat_template(
|
91 |
+
gradio_history_to_openai_conversations(
|
92 |
+
message, history=history, system_prompt=system_prompt),
|
93 |
+
add_generation_prompt=True
|
94 |
+
)
|
95 |
+
return full_prompt
|
96 |
+
|
97 |
+
|
98 |
+
|
99 |
+
def get_datetime_string():
|
100 |
+
from datetime import datetime
|
101 |
+
now = datetime.now()
|
102 |
+
# dd/mm/YY H:M:S
|
103 |
+
dt_string = now.strftime("%B %d, %Y, %H:%M:%S")
|
104 |
+
return dt_string
|
105 |
+
|
106 |
+
|
107 |
+
def format_conversation(history, system_prompt=None):
|
108 |
+
_str = '\n'.join([
|
109 |
+
(
|
110 |
+
f'<<<User>>> {h[0]}\n'
|
111 |
+
f'<<<Asst>>> {h[1]}'
|
112 |
+
)
|
113 |
+
for h in history
|
114 |
+
])
|
115 |
+
_str = ""
|
116 |
+
for mes, res in history:
|
117 |
+
if mes is not None:
|
118 |
+
_str += f'<<<User>>> {mes}\n'
|
119 |
+
if res is not None:
|
120 |
+
_str += f'<<<Asst>>> {res}\n'
|
121 |
+
if system_prompt is not None:
|
122 |
+
_str = f"<<<Syst>>> {system_prompt}\n" + _str
|
123 |
+
return _str
|
124 |
+
|
125 |
+
|
126 |
+
def chat_response_stream_multiturn_engine(
|
127 |
+
message: str,
|
128 |
+
history: List[Tuple[str, str]],
|
129 |
+
temperature: float,
|
130 |
+
max_tokens: int,
|
131 |
+
system_prompt: Optional[str] = SYSTEM_PROMPT,
|
132 |
+
):
|
133 |
+
global MODEL_ENGINE
|
134 |
+
temperature = float(temperature)
|
135 |
+
# ! remove frequency_penalty
|
136 |
+
# frequency_penalty = float(frequency_penalty)
|
137 |
+
max_tokens = int(max_tokens)
|
138 |
+
message = message.strip()
|
139 |
+
if len(message) == 0:
|
140 |
+
raise gr.Error("The message cannot be empty!")
|
141 |
+
# ! skip safety
|
142 |
+
if DATETIME_FORMAT in system_prompt:
|
143 |
+
# ! This sometime works sometimes dont
|
144 |
+
system_prompt = system_prompt.format(cur_datetime=get_datetime_string())
|
145 |
+
full_prompt = gradio_history_to_conversation_prompt(message.strip(), history=history, system_prompt=system_prompt)
|
146 |
+
# ! length checked
|
147 |
+
num_tokens = len(MODEL_ENGINE.tokenizer.encode(full_prompt))
|
148 |
+
if num_tokens >= MODEL_ENGINE.max_position_embeddings - 128:
|
149 |
+
raise gr.Error(f"Conversation or prompt is too long ({num_tokens} toks), please clear the chatbox or try shorter input.")
|
150 |
+
print(full_prompt)
|
151 |
+
outputs = None
|
152 |
+
response = None
|
153 |
+
num_tokens = -1
|
154 |
+
for j, outputs in enumerate(MODEL_ENGINE.generate_yield_string(
|
155 |
+
prompt=full_prompt,
|
156 |
+
temperature=temperature,
|
157 |
+
max_tokens=max_tokens,
|
158 |
+
)):
|
159 |
+
if isinstance(outputs, tuple):
|
160 |
+
response, num_tokens = outputs
|
161 |
+
else:
|
162 |
+
response, num_tokens = outputs, -1
|
163 |
+
yield response, num_tokens
|
164 |
+
|
165 |
+
print(format_conversation(history + [[message, response]]))
|
166 |
+
|
167 |
+
if response is not None:
|
168 |
+
yield response, num_tokens
|
169 |
+
|
170 |
+
|
171 |
+
class CustomizedChatInterface(gr.ChatInterface):
|
172 |
+
"""
|
173 |
+
Fixing some issue with chatinterace
|
174 |
+
"""
|
175 |
+
|
176 |
+
def __init__(
|
177 |
+
self,
|
178 |
+
fn: Callable,
|
179 |
+
*,
|
180 |
+
chatbot: Chatbot | None = None,
|
181 |
+
textbox: Textbox | None = None,
|
182 |
+
additional_inputs: str | Component | list[str | Component] | None = None,
|
183 |
+
additional_inputs_accordion_name: str | None = None,
|
184 |
+
additional_inputs_accordion: str | Accordion | None = None,
|
185 |
+
examples: list[str] | None = None,
|
186 |
+
cache_examples: bool | None = None,
|
187 |
+
title: str | None = None,
|
188 |
+
description: str | None = None,
|
189 |
+
theme: Theme | str | None = None,
|
190 |
+
css: str | None = None,
|
191 |
+
js: str | None = None,
|
192 |
+
head: str | None = None,
|
193 |
+
analytics_enabled: bool | None = None,
|
194 |
+
submit_btn: str | None | Button = "Submit",
|
195 |
+
stop_btn: str | None | Button = "Stop",
|
196 |
+
retry_btn: str | None | Button = "🔄 Retry",
|
197 |
+
undo_btn: str | None | Button = "↩️ Undo",
|
198 |
+
clear_btn: str | None | Button = "🗑️ Clear",
|
199 |
+
autofocus: bool = True,
|
200 |
+
concurrency_limit: int | None | Literal["default"] = "default",
|
201 |
+
fill_height: bool = True,
|
202 |
+
):
|
203 |
+
"""
|
204 |
+
Parameters:
|
205 |
+
fn: The function to wrap the chat interface around. Should accept two parameters: a string input message and list of two-element lists of the form [[user_message, bot_message], ...] representing the chat history, and return a string response. See the Chatbot documentation for more information on the chat history format.
|
206 |
+
chatbot: An instance of the gr.Chatbot component to use for the chat interface, if you would like to customize the chatbot properties. If not provided, a default gr.Chatbot component will be created.
|
207 |
+
textbox: An instance of the gr.Textbox component to use for the chat interface, if you would like to customize the textbox properties. If not provided, a default gr.Textbox component will be created.
|
208 |
+
additional_inputs: An instance or list of instances of gradio components (or their string shortcuts) to use as additional inputs to the chatbot. If components are not already rendered in a surrounding Blocks, then the components will be displayed under the chatbot, in an accordion.
|
209 |
+
additional_inputs_accordion_name: Deprecated. Will be removed in a future version of Gradio. Use the `additional_inputs_accordion` parameter instead.
|
210 |
+
additional_inputs_accordion: If a string is provided, this is the label of the `gr.Accordion` to use to contain additional inputs. A `gr.Accordion` object can be provided as well to configure other properties of the container holding the additional inputs. Defaults to a `gr.Accordion(label="Additional Inputs", open=False)`. This parameter is only used if `additional_inputs` is provided.
|
211 |
+
examples: Sample inputs for the function; if provided, appear below the chatbot and can be clicked to populate the chatbot input.
|
212 |
+
cache_examples: If True, caches examples in the server for fast runtime in examples. The default option in HuggingFace Spaces is True. The default option elsewhere is False.
|
213 |
+
title: a title for the interface; if provided, appears above chatbot in large font. Also used as the tab title when opened in a browser window.
|
214 |
+
description: a description for the interface; if provided, appears above the chatbot and beneath the title in regular font. Accepts Markdown and HTML content.
|
215 |
+
theme: Theme to use, loaded from gradio.themes.
|
216 |
+
css: Custom css as a string or path to a css file. This css will be included in the demo webpage.
|
217 |
+
js: Custom js or path to js file to run when demo is first loaded. This javascript will be included in the demo webpage.
|
218 |
+
head: Custom html to insert into the head of the demo webpage. This can be used to add custom meta tags, scripts, stylesheets, etc. to the page.
|
219 |
+
analytics_enabled: Whether to allow basic telemetry. If None, will use GRADIO_ANALYTICS_ENABLED environment variable if defined, or default to True.
|
220 |
+
submit_btn: Text to display on the submit button. If None, no button will be displayed. If a Button object, that button will be used.
|
221 |
+
stop_btn: Text to display on the stop button, which replaces the submit_btn when the submit_btn or retry_btn is clicked and response is streaming. Clicking on the stop_btn will halt the chatbot response. If set to None, stop button functionality does not appear in the chatbot. If a Button object, that button will be used as the stop button.
|
222 |
+
retry_btn: Text to display on the retry button. If None, no button will be displayed. If a Button object, that button will be used.
|
223 |
+
undo_btn: Text to display on the delete last button. If None, no button will be displayed. If a Button object, that button will be used.
|
224 |
+
clear_btn: Text to display on the clear button. If None, no button will be displayed. If a Button object, that button will be used.
|
225 |
+
autofocus: If True, autofocuses to the textbox when the page loads.
|
226 |
+
concurrency_limit: If set, this is the maximum number of chatbot submissions that can be running simultaneously. Can be set to None to mean no limit (any number of chatbot submissions can be running simultaneously). Set to "default" to use the default concurrency limit (defined by the `default_concurrency_limit` parameter in `.queue()`, which is 1 by default).
|
227 |
+
fill_height: If True, the chat interface will expand to the height of window.
|
228 |
+
"""
|
229 |
+
try:
|
230 |
+
super(gr.ChatInterface, self).__init__(
|
231 |
+
analytics_enabled=analytics_enabled,
|
232 |
+
mode="chat_interface",
|
233 |
+
css=css,
|
234 |
+
title=title or "Gradio",
|
235 |
+
theme=theme,
|
236 |
+
js=js,
|
237 |
+
head=head,
|
238 |
+
fill_height=fill_height,
|
239 |
+
)
|
240 |
+
except Exception as e:
|
241 |
+
# Handling some old gradio version with out fill_height
|
242 |
+
super(gr.ChatInterface, self).__init__(
|
243 |
+
analytics_enabled=analytics_enabled,
|
244 |
+
mode="chat_interface",
|
245 |
+
css=css,
|
246 |
+
title=title or "Gradio",
|
247 |
+
theme=theme,
|
248 |
+
js=js,
|
249 |
+
head=head,
|
250 |
+
# fill_height=fill_height,
|
251 |
+
)
|
252 |
+
self.concurrency_limit = concurrency_limit
|
253 |
+
self.fn = fn
|
254 |
+
self.is_async = inspect.iscoroutinefunction(
|
255 |
+
self.fn
|
256 |
+
) or inspect.isasyncgenfunction(self.fn)
|
257 |
+
self.is_generator = inspect.isgeneratorfunction(
|
258 |
+
self.fn
|
259 |
+
) or inspect.isasyncgenfunction(self.fn)
|
260 |
+
self.examples = examples
|
261 |
+
if self.space_id and cache_examples is None:
|
262 |
+
self.cache_examples = True
|
263 |
+
else:
|
264 |
+
self.cache_examples = cache_examples or False
|
265 |
+
self.buttons: list[Button | None] = []
|
266 |
+
|
267 |
+
if additional_inputs:
|
268 |
+
if not isinstance(additional_inputs, list):
|
269 |
+
additional_inputs = [additional_inputs]
|
270 |
+
self.additional_inputs = [
|
271 |
+
get_component_instance(i)
|
272 |
+
for i in additional_inputs # type: ignore
|
273 |
+
]
|
274 |
+
else:
|
275 |
+
self.additional_inputs = []
|
276 |
+
if additional_inputs_accordion_name is not None:
|
277 |
+
print(
|
278 |
+
"The `additional_inputs_accordion_name` parameter is deprecated and will be removed in a future version of Gradio. Use the `additional_inputs_accordion` parameter instead."
|
279 |
+
)
|
280 |
+
self.additional_inputs_accordion_params = {
|
281 |
+
"label": additional_inputs_accordion_name
|
282 |
+
}
|
283 |
+
if additional_inputs_accordion is None:
|
284 |
+
self.additional_inputs_accordion_params = {
|
285 |
+
"label": "Additional Inputs",
|
286 |
+
"open": False,
|
287 |
+
}
|
288 |
+
elif isinstance(additional_inputs_accordion, str):
|
289 |
+
self.additional_inputs_accordion_params = {
|
290 |
+
"label": additional_inputs_accordion
|
291 |
+
}
|
292 |
+
elif isinstance(additional_inputs_accordion, Accordion):
|
293 |
+
self.additional_inputs_accordion_params = (
|
294 |
+
additional_inputs_accordion.recover_kwargs(
|
295 |
+
additional_inputs_accordion.get_config()
|
296 |
+
)
|
297 |
+
)
|
298 |
+
else:
|
299 |
+
raise ValueError(
|
300 |
+
f"The `additional_inputs_accordion` parameter must be a string or gr.Accordion, not {type(additional_inputs_accordion)}"
|
301 |
+
)
|
302 |
+
|
303 |
+
with self:
|
304 |
+
if title:
|
305 |
+
Markdown(
|
306 |
+
f"<h1 style='text-align: center; margin-bottom: 1rem'>{self.title}</h1>"
|
307 |
+
)
|
308 |
+
if description:
|
309 |
+
Markdown(description)
|
310 |
+
|
311 |
+
if chatbot:
|
312 |
+
self.chatbot = chatbot.render()
|
313 |
+
else:
|
314 |
+
self.chatbot = Chatbot(
|
315 |
+
label="Chatbot", scale=1, height=200 if fill_height else None
|
316 |
+
)
|
317 |
+
|
318 |
+
with Row():
|
319 |
+
for btn in [retry_btn, undo_btn, clear_btn]:
|
320 |
+
if btn is not None:
|
321 |
+
if isinstance(btn, Button):
|
322 |
+
btn.render()
|
323 |
+
elif isinstance(btn, str):
|
324 |
+
btn = Button(btn, variant="secondary", size="sm")
|
325 |
+
else:
|
326 |
+
raise ValueError(
|
327 |
+
f"All the _btn parameters must be a gr.Button, string, or None, not {type(btn)}"
|
328 |
+
)
|
329 |
+
self.buttons.append(btn) # type: ignore
|
330 |
+
|
331 |
+
with Group():
|
332 |
+
with Row():
|
333 |
+
if textbox:
|
334 |
+
textbox.container = False
|
335 |
+
textbox.show_label = False
|
336 |
+
textbox_ = textbox.render()
|
337 |
+
assert isinstance(textbox_, Textbox)
|
338 |
+
self.textbox = textbox_
|
339 |
+
else:
|
340 |
+
self.textbox = Textbox(
|
341 |
+
container=False,
|
342 |
+
show_label=False,
|
343 |
+
label="Message",
|
344 |
+
placeholder="Type a message...",
|
345 |
+
scale=7,
|
346 |
+
autofocus=autofocus,
|
347 |
+
)
|
348 |
+
if submit_btn is not None:
|
349 |
+
if isinstance(submit_btn, Button):
|
350 |
+
submit_btn.render()
|
351 |
+
elif isinstance(submit_btn, str):
|
352 |
+
submit_btn = Button(
|
353 |
+
submit_btn,
|
354 |
+
variant="primary",
|
355 |
+
scale=2,
|
356 |
+
min_width=150,
|
357 |
+
)
|
358 |
+
else:
|
359 |
+
raise ValueError(
|
360 |
+
f"The submit_btn parameter must be a gr.Button, string, or None, not {type(submit_btn)}"
|
361 |
+
)
|
362 |
+
if stop_btn is not None:
|
363 |
+
if isinstance(stop_btn, Button):
|
364 |
+
stop_btn.visible = False
|
365 |
+
stop_btn.render()
|
366 |
+
elif isinstance(stop_btn, str):
|
367 |
+
stop_btn = Button(
|
368 |
+
stop_btn,
|
369 |
+
variant="stop",
|
370 |
+
visible=False,
|
371 |
+
scale=2,
|
372 |
+
min_width=150,
|
373 |
+
)
|
374 |
+
else:
|
375 |
+
raise ValueError(
|
376 |
+
f"The stop_btn parameter must be a gr.Button, string, or None, not {type(stop_btn)}"
|
377 |
+
)
|
378 |
+
self.num_tokens = Textbox(
|
379 |
+
container=False,
|
380 |
+
show_label=False,
|
381 |
+
label="num_tokens",
|
382 |
+
placeholder="0 tokens",
|
383 |
+
scale=1,
|
384 |
+
interactive=False,
|
385 |
+
# autofocus=autofocus,
|
386 |
+
min_width=10
|
387 |
+
)
|
388 |
+
self.buttons.extend([submit_btn, stop_btn]) # type: ignore
|
389 |
+
|
390 |
+
self.fake_api_btn = Button("Fake API", visible=False)
|
391 |
+
self.fake_response_textbox = Textbox(label="Response", visible=False)
|
392 |
+
(
|
393 |
+
self.retry_btn,
|
394 |
+
self.undo_btn,
|
395 |
+
self.clear_btn,
|
396 |
+
self.submit_btn,
|
397 |
+
self.stop_btn,
|
398 |
+
) = self.buttons
|
399 |
+
|
400 |
+
if examples:
|
401 |
+
if self.is_generator:
|
402 |
+
examples_fn = self._examples_stream_fn
|
403 |
+
else:
|
404 |
+
examples_fn = self._examples_fn
|
405 |
+
|
406 |
+
self.examples_handler = Examples(
|
407 |
+
examples=examples,
|
408 |
+
inputs=[self.textbox] + self.additional_inputs,
|
409 |
+
outputs=self.chatbot,
|
410 |
+
fn=examples_fn,
|
411 |
+
)
|
412 |
+
|
413 |
+
any_unrendered_inputs = any(
|
414 |
+
not inp.is_rendered for inp in self.additional_inputs
|
415 |
+
)
|
416 |
+
if self.additional_inputs and any_unrendered_inputs:
|
417 |
+
with Accordion(**self.additional_inputs_accordion_params): # type: ignore
|
418 |
+
for input_component in self.additional_inputs:
|
419 |
+
if not input_component.is_rendered:
|
420 |
+
input_component.render()
|
421 |
+
|
422 |
+
# The example caching must happen after the input components have rendered
|
423 |
+
if cache_examples:
|
424 |
+
client_utils.synchronize_async(self.examples_handler.cache)
|
425 |
+
|
426 |
+
self.saved_input = State()
|
427 |
+
self.chatbot_state = (
|
428 |
+
State(self.chatbot.value) if self.chatbot.value else State([])
|
429 |
+
)
|
430 |
+
|
431 |
+
self._setup_events()
|
432 |
+
self._setup_api()
|
433 |
+
|
434 |
+
# replace events so that submit button is disabled during generation, if stop_btn not found
|
435 |
+
# this prevent weird behavior
|
436 |
+
def _setup_stop_events(
|
437 |
+
self, event_triggers: list[EventListenerMethod], event_to_cancel: Dependency
|
438 |
+
) -> None:
|
439 |
+
from gradio.components import State
|
440 |
+
event_triggers = event_triggers if isinstance(event_triggers, (list, tuple)) else [event_triggers]
|
441 |
+
if self.stop_btn and self.is_generator:
|
442 |
+
if self.submit_btn:
|
443 |
+
for event_trigger in event_triggers:
|
444 |
+
event_trigger(
|
445 |
+
lambda: (
|
446 |
+
Button(visible=False),
|
447 |
+
Button(visible=True),
|
448 |
+
),
|
449 |
+
None,
|
450 |
+
[self.submit_btn, self.stop_btn],
|
451 |
+
api_name=False,
|
452 |
+
queue=False,
|
453 |
+
)
|
454 |
+
event_to_cancel.then(
|
455 |
+
lambda: (Button(visible=True), Button(visible=False)),
|
456 |
+
None,
|
457 |
+
[self.submit_btn, self.stop_btn],
|
458 |
+
api_name=False,
|
459 |
+
queue=False,
|
460 |
+
)
|
461 |
+
else:
|
462 |
+
for event_trigger in event_triggers:
|
463 |
+
event_trigger(
|
464 |
+
lambda: Button(visible=True),
|
465 |
+
None,
|
466 |
+
[self.stop_btn],
|
467 |
+
api_name=False,
|
468 |
+
queue=False,
|
469 |
+
)
|
470 |
+
event_to_cancel.then(
|
471 |
+
lambda: Button(visible=False),
|
472 |
+
None,
|
473 |
+
[self.stop_btn],
|
474 |
+
api_name=False,
|
475 |
+
queue=False,
|
476 |
+
)
|
477 |
+
self.stop_btn.click(
|
478 |
+
None,
|
479 |
+
None,
|
480 |
+
None,
|
481 |
+
cancels=event_to_cancel,
|
482 |
+
api_name=False,
|
483 |
+
)
|
484 |
+
else:
|
485 |
+
if self.submit_btn:
|
486 |
+
for event_trigger in event_triggers:
|
487 |
+
event_trigger(
|
488 |
+
lambda: Button(interactive=False),
|
489 |
+
None,
|
490 |
+
[self.submit_btn],
|
491 |
+
api_name=False,
|
492 |
+
queue=False,
|
493 |
+
)
|
494 |
+
event_to_cancel.then(
|
495 |
+
lambda: Button(interactive=True),
|
496 |
+
None,
|
497 |
+
[self.submit_btn],
|
498 |
+
api_name=False,
|
499 |
+
queue=False,
|
500 |
+
)
|
501 |
+
# upon clear, cancel the submit event as well
|
502 |
+
if self.clear_btn:
|
503 |
+
self.clear_btn.click(
|
504 |
+
lambda: ([], [], None, Button(interactive=True)),
|
505 |
+
None,
|
506 |
+
[self.chatbot, self.chatbot_state, self.saved_input, self.submit_btn],
|
507 |
+
queue=False,
|
508 |
+
api_name=False,
|
509 |
+
cancels=event_to_cancel,
|
510 |
+
)
|
511 |
+
|
512 |
+
def _setup_events(self) -> None:
|
513 |
+
from gradio.components import State
|
514 |
+
has_on = False
|
515 |
+
try:
|
516 |
+
from gradio.events import Dependency, EventListenerMethod, on
|
517 |
+
has_on = True
|
518 |
+
except ImportError as ie:
|
519 |
+
has_on = False
|
520 |
+
submit_fn = self._stream_fn if self.is_generator else self._submit_fn
|
521 |
+
if not self.is_generator:
|
522 |
+
raise NotImplementedError(f'should use generator')
|
523 |
+
|
524 |
+
if has_on:
|
525 |
+
# new version
|
526 |
+
submit_triggers = (
|
527 |
+
[self.textbox.submit, self.submit_btn.click]
|
528 |
+
if self.submit_btn
|
529 |
+
else [self.textbox.submit]
|
530 |
+
)
|
531 |
+
submit_event = (
|
532 |
+
on(
|
533 |
+
submit_triggers,
|
534 |
+
self._clear_and_save_textbox,
|
535 |
+
[self.textbox],
|
536 |
+
[self.textbox, self.saved_input],
|
537 |
+
api_name=False,
|
538 |
+
queue=False,
|
539 |
+
)
|
540 |
+
.then(
|
541 |
+
self._display_input,
|
542 |
+
[self.saved_input, self.chatbot_state],
|
543 |
+
[self.chatbot, self.chatbot_state],
|
544 |
+
api_name=False,
|
545 |
+
queue=False,
|
546 |
+
)
|
547 |
+
.then(
|
548 |
+
submit_fn,
|
549 |
+
[self.saved_input, self.chatbot_state] + self.additional_inputs,
|
550 |
+
[self.chatbot, self.chatbot_state, self.num_tokens],
|
551 |
+
api_name=False,
|
552 |
+
)
|
553 |
+
)
|
554 |
+
self._setup_stop_events(submit_triggers, submit_event)
|
555 |
+
else:
|
556 |
+
raise ValueError(f'Better install new gradio version than 3.44.0')
|
557 |
+
|
558 |
+
if self.retry_btn:
|
559 |
+
retry_event = (
|
560 |
+
self.retry_btn.click(
|
561 |
+
self._delete_prev_fn,
|
562 |
+
[self.chatbot_state],
|
563 |
+
[self.chatbot, self.saved_input, self.chatbot_state],
|
564 |
+
api_name=False,
|
565 |
+
queue=False,
|
566 |
+
)
|
567 |
+
.then(
|
568 |
+
self._display_input,
|
569 |
+
[self.saved_input, self.chatbot_state],
|
570 |
+
[self.chatbot, self.chatbot_state],
|
571 |
+
api_name=False,
|
572 |
+
queue=False,
|
573 |
+
)
|
574 |
+
.then(
|
575 |
+
submit_fn,
|
576 |
+
[self.saved_input, self.chatbot_state] + self.additional_inputs,
|
577 |
+
[self.chatbot, self.chatbot_state, self.num_tokens],
|
578 |
+
api_name=False,
|
579 |
+
)
|
580 |
+
)
|
581 |
+
self._setup_stop_events([self.retry_btn.click], retry_event)
|
582 |
+
|
583 |
+
if self.undo_btn:
|
584 |
+
self.undo_btn.click(
|
585 |
+
self._delete_prev_fn,
|
586 |
+
[self.chatbot_state],
|
587 |
+
[self.chatbot, self.saved_input, self.chatbot_state],
|
588 |
+
api_name=False,
|
589 |
+
queue=False,
|
590 |
+
).then(
|
591 |
+
lambda x: x,
|
592 |
+
[self.saved_input],
|
593 |
+
[self.textbox],
|
594 |
+
api_name=False,
|
595 |
+
queue=False,
|
596 |
+
)
|
597 |
+
# Reconfigure clear_btn to stop and clear text box
|
598 |
+
|
599 |
+
def _clear_and_save_textbox(self, message: str) -> tuple[str, str]:
|
600 |
+
return "", message
|
601 |
+
|
602 |
+
def _display_input(
|
603 |
+
self, message: str, history: List[List[Union[str, None]]]
|
604 |
+
) -> Tuple[List[List[Union[str, None]]], List[List[list[Union[str, None]]]]]:
|
605 |
+
if message is not None and message.strip() != "":
|
606 |
+
history.append([message, None])
|
607 |
+
return history, history
|
608 |
+
|
609 |
+
async def _stream_fn(
|
610 |
+
self,
|
611 |
+
message: str,
|
612 |
+
history_with_input,
|
613 |
+
request: Request,
|
614 |
+
*args,
|
615 |
+
) -> AsyncGenerator:
|
616 |
+
history = history_with_input[:-1]
|
617 |
+
inputs, _, _ = special_args(
|
618 |
+
self.fn, inputs=[message, history, *args], request=request
|
619 |
+
)
|
620 |
+
|
621 |
+
if self.is_async:
|
622 |
+
generator = self.fn(*inputs)
|
623 |
+
else:
|
624 |
+
generator = await anyio.to_thread.run_sync(
|
625 |
+
self.fn, *inputs, limiter=self.limiter
|
626 |
+
)
|
627 |
+
generator = SyncToAsyncIterator(generator, self.limiter)
|
628 |
+
|
629 |
+
# ! In case of error, yield the previous history & undo any generation before raising error
|
630 |
+
try:
|
631 |
+
first_response_pack = await async_iteration(generator)
|
632 |
+
if isinstance(first_response_pack, (tuple, list)):
|
633 |
+
first_response, num_tokens = first_response_pack
|
634 |
+
else:
|
635 |
+
first_response, num_tokens = first_response_pack, -1
|
636 |
+
update = history + [[message, first_response]]
|
637 |
+
yield update, update, f"{num_tokens} toks"
|
638 |
+
except StopIteration:
|
639 |
+
update = history + [[message, None]]
|
640 |
+
yield update, update, "NaN toks"
|
641 |
+
except Exception as e:
|
642 |
+
yield history, history, "NaN toks"
|
643 |
+
raise e
|
644 |
+
|
645 |
+
try:
|
646 |
+
async for response_pack in generator:
|
647 |
+
if isinstance(response_pack, (tuple, list)):
|
648 |
+
response, num_tokens = response_pack
|
649 |
+
else:
|
650 |
+
response, num_tokens = response_pack, "NaN toks"
|
651 |
+
update = history + [[message, response]]
|
652 |
+
yield update, update, f"{num_tokens} toks"
|
653 |
+
except Exception as e:
|
654 |
+
yield history, history, "NaN toks"
|
655 |
+
raise e
|
656 |
+
|
657 |
+
@register_demo
|
658 |
+
class ChatInterfaceDemo(BaseDemo):
|
659 |
+
@property
|
660 |
+
def tab_name(self):
|
661 |
+
return "Chat"
|
662 |
+
|
663 |
+
def create_demo(
|
664 |
+
self,
|
665 |
+
title: str | None = None,
|
666 |
+
description: str | None = None,
|
667 |
+
**kwargs
|
668 |
+
) -> gr.Blocks:
|
669 |
+
system_prompt = kwargs.get("system_prompt", SYSTEM_PROMPT)
|
670 |
+
max_tokens = kwargs.get("max_tokens", MAX_TOKENS)
|
671 |
+
temperature = kwargs.get("temperature", TEMPERATURE)
|
672 |
+
model_name = kwargs.get("model_name", MODEL_NAME)
|
673 |
+
# frequence_penalty = FREQUENCE_PENALTY
|
674 |
+
# presence_penalty = PRESENCE_PENALTY
|
675 |
+
|
676 |
+
demo_chat = CustomizedChatInterface(
|
677 |
+
chat_response_stream_multiturn_engine,
|
678 |
+
chatbot=gr.Chatbot(
|
679 |
+
label=model_name,
|
680 |
+
bubble_full_width=False,
|
681 |
+
latex_delimiters=[
|
682 |
+
{ "left": "$", "right": "$", "display": False},
|
683 |
+
{ "left": "$$", "right": "$$", "display": True},
|
684 |
+
],
|
685 |
+
show_copy_button=True,
|
686 |
+
layout="panel" if USE_PANEL else "bubble",
|
687 |
+
height=CHATBOT_HEIGHT,
|
688 |
+
),
|
689 |
+
textbox=gr.Textbox(placeholder='Type message', lines=1, max_lines=128, min_width=200, scale=8),
|
690 |
+
submit_btn=gr.Button(value='Submit', variant="primary", scale=0),
|
691 |
+
title=title,
|
692 |
+
description=description,
|
693 |
+
additional_inputs=[
|
694 |
+
gr.Number(value=temperature, label='Temperature (higher -> more random)'),
|
695 |
+
gr.Number(value=max_tokens, label='Max generated tokens (increase if want more generation)'),
|
696 |
+
# gr.Number(value=frequence_penalty, label='Frequency penalty (> 0 encourage new tokens over repeated tokens)'),
|
697 |
+
# gr.Number(value=presence_penalty, label='Presence penalty (> 0 encourage new tokens, < 0 encourage existing tokens)'),
|
698 |
+
gr.Textbox(value=system_prompt, label='System prompt', lines=4)
|
699 |
+
],
|
700 |
+
examples=CHAT_EXAMPLES,
|
701 |
+
cache_examples=False
|
702 |
+
)
|
703 |
+
return demo_chat
|
704 |
+
|
multipurpose_chatbot/demos/multimodal_chat_interface.py
ADDED
@@ -0,0 +1,1293 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
from gradio.themes import ThemeClass as Theme
|
3 |
+
import numpy as np
|
4 |
+
import argparse
|
5 |
+
import gradio as gr
|
6 |
+
from typing import Any, Iterator
|
7 |
+
from typing import Iterator, List, Optional, Tuple
|
8 |
+
import filelock
|
9 |
+
import glob
|
10 |
+
import json
|
11 |
+
import time
|
12 |
+
from gradio.routes import Request
|
13 |
+
from gradio.utils import SyncToAsyncIterator, async_iteration
|
14 |
+
from gradio.helpers import special_args
|
15 |
+
import anyio
|
16 |
+
from typing import AsyncGenerator, Callable, Literal, Union, cast, Generator
|
17 |
+
|
18 |
+
from gradio_client.documentation import document, set_documentation_group
|
19 |
+
from gradio.components import Button, Component
|
20 |
+
from gradio.events import Dependency, EventListenerMethod
|
21 |
+
from typing import List, Optional, Union, Dict, Tuple
|
22 |
+
from tqdm.auto import tqdm
|
23 |
+
from huggingface_hub import snapshot_download
|
24 |
+
from gradio.components.base import Component
|
25 |
+
|
26 |
+
from .base_demo import register_demo, get_demo_class, BaseDemo
|
27 |
+
|
28 |
+
|
29 |
+
from .chat_interface import (
|
30 |
+
SYSTEM_PROMPT,
|
31 |
+
MODEL_NAME,
|
32 |
+
MAX_TOKENS,
|
33 |
+
TEMPERATURE,
|
34 |
+
CHAT_EXAMPLES,
|
35 |
+
format_conversation,
|
36 |
+
gradio_history_to_openai_conversations,
|
37 |
+
gradio_history_to_conversation_prompt,
|
38 |
+
DATETIME_FORMAT,
|
39 |
+
get_datetime_string,
|
40 |
+
chat_response_stream_multiturn_engine,
|
41 |
+
ChatInterfaceDemo,
|
42 |
+
CustomizedChatInterface,
|
43 |
+
)
|
44 |
+
|
45 |
+
from gradio.events import Events
|
46 |
+
|
47 |
+
import inspect
|
48 |
+
from typing import AsyncGenerator, Callable, Literal, Union, cast
|
49 |
+
|
50 |
+
import anyio
|
51 |
+
from gradio_client import utils as client_utils
|
52 |
+
from gradio_client.documentation import document
|
53 |
+
|
54 |
+
from gradio.blocks import Blocks
|
55 |
+
from gradio.components import (
|
56 |
+
Button,
|
57 |
+
Chatbot,
|
58 |
+
Component,
|
59 |
+
Markdown,
|
60 |
+
State,
|
61 |
+
Textbox,
|
62 |
+
get_component_instance,
|
63 |
+
)
|
64 |
+
from gradio.events import Dependency, on
|
65 |
+
from gradio.helpers import create_examples as Examples # noqa: N812
|
66 |
+
from gradio.helpers import special_args
|
67 |
+
from gradio.layouts import Accordion, Group, Row
|
68 |
+
from gradio.routes import Request
|
69 |
+
from gradio.themes import ThemeClass as Theme
|
70 |
+
from gradio.utils import SyncToAsyncIterator, async_iteration
|
71 |
+
|
72 |
+
from ..globals import MODEL_ENGINE
|
73 |
+
|
74 |
+
from ..configs import (
|
75 |
+
USE_PANEL,
|
76 |
+
IMAGE_TOKEN,
|
77 |
+
IMAGE_TOKEN_INTERACTIVE,
|
78 |
+
CHATBOT_HEIGHT,
|
79 |
+
)
|
80 |
+
|
81 |
+
|
82 |
+
|
83 |
+
CSS = """
|
84 |
+
.message-fit {
|
85 |
+
min-width: 20em;
|
86 |
+
width: fit-content !important;
|
87 |
+
}
|
88 |
+
|
89 |
+
.message.svelte-1lcyrx4.svelte-1lcyrx4.svelte-1lcyrx4 {
|
90 |
+
padding-top: 1em;
|
91 |
+
padding-bottom: 1em;
|
92 |
+
}
|
93 |
+
"""
|
94 |
+
|
95 |
+
|
96 |
+
DOC_TEMPLATE = """###
|
97 |
+
{content}
|
98 |
+
###
|
99 |
+
|
100 |
+
"""
|
101 |
+
|
102 |
+
DOC_INSTRUCTION = """Answer the following query exclusively based on the information provided in the document above. \
|
103 |
+
If the information is not found, please say so instead of making up facts! Remember to answer the question in the same language as the user query!
|
104 |
+
"""
|
105 |
+
|
106 |
+
|
107 |
+
def undo_history(history):
|
108 |
+
if len(history) == 0:
|
109 |
+
return history
|
110 |
+
if history[-1][-1] is not None:
|
111 |
+
if history[-1][0] is not None:
|
112 |
+
history[-1][-1] = None
|
113 |
+
else:
|
114 |
+
history = history[:-1]
|
115 |
+
else:
|
116 |
+
history = history[:-1]
|
117 |
+
return history
|
118 |
+
|
119 |
+
|
120 |
+
def undo_history_until_last_assistant_turn(history):
|
121 |
+
history = undo_history(history)
|
122 |
+
while len(history) > 0 and history[-1][-1] is None:
|
123 |
+
history = undo_history(history)
|
124 |
+
return history, history
|
125 |
+
|
126 |
+
|
127 |
+
class MultiModalChatInterface(CustomizedChatInterface):
|
128 |
+
def __init__(
|
129 |
+
self,
|
130 |
+
fn: Callable,
|
131 |
+
*,
|
132 |
+
chatbot: Chatbot | None = None,
|
133 |
+
textbox: Textbox | None = None,
|
134 |
+
additional_inputs: str | Component | list[str | Component] | None = None,
|
135 |
+
additional_inputs_accordion_name: str | None = None,
|
136 |
+
additional_inputs_accordion: str | Accordion | None = None,
|
137 |
+
add_multimodal_fn: Callable | None = None,
|
138 |
+
render_additional_inputs_fn: Callable | None = None,
|
139 |
+
examples: list[str] | None = None,
|
140 |
+
cache_examples: bool | None = None,
|
141 |
+
title: str | None = None,
|
142 |
+
description: str | None = None,
|
143 |
+
theme: Theme | str | None = None,
|
144 |
+
css: str | None = None,
|
145 |
+
js: str | None = None,
|
146 |
+
head: str | None = None,
|
147 |
+
analytics_enabled: bool | None = None,
|
148 |
+
submit_btn: str | None | Button = "Submit",
|
149 |
+
stop_btn: str | None | Button = "Stop",
|
150 |
+
retry_btn: str | None | Button = "🔄 Retry",
|
151 |
+
undo_btn: str | None | Button = "↩️ Undo",
|
152 |
+
clear_btn: str | None | Button = "🗑️ Clear",
|
153 |
+
autofocus: bool = True,
|
154 |
+
concurrency_limit: int | None | Literal["default"] = "default",
|
155 |
+
fill_height: bool = True,
|
156 |
+
):
|
157 |
+
"""
|
158 |
+
Parameters:
|
159 |
+
fn: The function to wrap the chat interface around. Should accept two parameters: a string input message and list of two-element lists of the form [[user_message, bot_message], ...] representing the chat history, and return a string response. See the Chatbot documentation for more information on the chat history format.
|
160 |
+
chatbot: An instance of the gr.Chatbot component to use for the chat interface, if you would like to customize the chatbot properties. If not provided, a default gr.Chatbot component will be created.
|
161 |
+
textbox: An instance of the gr.Textbox component to use for the chat interface, if you would like to customize the textbox properties. If not provided, a default gr.Textbox component will be created.
|
162 |
+
additional_inputs: An instance or list of instances of gradio components (or their string shortcuts) to use as additional inputs to the chatbot. If components are not already rendered in a surrounding Blocks, then the components will be displayed under the chatbot, in an accordion.
|
163 |
+
additional_inputs_accordion_name: Deprecated. Will be removed in a future version of Gradio. Use the `additional_inputs_accordion` parameter instead.
|
164 |
+
additional_inputs_accordion: If a string is provided, this is the label of the `gr.Accordion` to use to contain additional inputs. A `gr.Accordion` object can be provided as well to configure other properties of the container holding the additional inputs. Defaults to a `gr.Accordion(label="Additional Inputs", open=False)`. This parameter is only used if `additional_inputs` is provided.
|
165 |
+
examples: Sample inputs for the function; if provided, appear below the chatbot and can be clicked to populate the chatbot input.
|
166 |
+
cache_examples: If True, caches examples in the server for fast runtime in examples. The default option in HuggingFace Spaces is True. The default option elsewhere is False.
|
167 |
+
title: a title for the interface; if provided, appears above chatbot in large font. Also used as the tab title when opened in a browser window.
|
168 |
+
description: a description for the interface; if provided, appears above the chatbot and beneath the title in regular font. Accepts Markdown and HTML content.
|
169 |
+
theme: Theme to use, loaded from gradio.themes.
|
170 |
+
css: Custom css as a string or path to a css file. This css will be included in the demo webpage.
|
171 |
+
js: Custom js or path to js file to run when demo is first loaded. This javascript will be included in the demo webpage.
|
172 |
+
head: Custom html to insert into the head of the demo webpage. This can be used to add custom meta tags, scripts, stylesheets, etc. to the page.
|
173 |
+
analytics_enabled: Whether to allow basic telemetry. If None, will use GRADIO_ANALYTICS_ENABLED environment variable if defined, or default to True.
|
174 |
+
submit_btn: Text to display on the submit button. If None, no button will be displayed. If a Button object, that button will be used.
|
175 |
+
stop_btn: Text to display on the stop button, which replaces the submit_btn when the submit_btn or retry_btn is clicked and response is streaming. Clicking on the stop_btn will halt the chatbot response. If set to None, stop button functionality does not appear in the chatbot. If a Button object, that button will be used as the stop button.
|
176 |
+
retry_btn: Text to display on the retry button. If None, no button will be displayed. If a Button object, that button will be used.
|
177 |
+
undo_btn: Text to display on the delete last button. If None, no button will be displayed. If a Button object, that button will be used.
|
178 |
+
clear_btn: Text to display on the clear button. If None, no button will be displayed. If a Button object, that button will be used.
|
179 |
+
autofocus: If True, autofocuses to the textbox when the page loads.
|
180 |
+
concurrency_limit: If set, this is the maximum number of chatbot submissions that can be running simultaneously. Can be set to None to mean no limit (any number of chatbot submissions can be running simultaneously). Set to "default" to use the default concurrency limit (defined by the `default_concurrency_limit` parameter in `.queue()`, which is 1 by default).
|
181 |
+
fill_height: If True, the chat interface will expand to the height of window.
|
182 |
+
"""
|
183 |
+
try:
|
184 |
+
super(gr.ChatInterface, self).__init__(
|
185 |
+
analytics_enabled=analytics_enabled,
|
186 |
+
mode="chat_interface",
|
187 |
+
css=css,
|
188 |
+
title=title or "Gradio",
|
189 |
+
theme=theme,
|
190 |
+
js=js,
|
191 |
+
head=head,
|
192 |
+
fill_height=fill_height,
|
193 |
+
)
|
194 |
+
except Exception as e:
|
195 |
+
# Handle old gradio versions without fill_height
|
196 |
+
super(gr.ChatInterface, self).__init__(
|
197 |
+
analytics_enabled=analytics_enabled,
|
198 |
+
mode="chat_interface",
|
199 |
+
css=css,
|
200 |
+
title=title or "Gradio",
|
201 |
+
theme=theme,
|
202 |
+
js=js,
|
203 |
+
head=head,
|
204 |
+
# fill_height=fill_height,
|
205 |
+
)
|
206 |
+
|
207 |
+
self.concurrency_limit = concurrency_limit
|
208 |
+
self.fn = fn
|
209 |
+
self.add_multimodal_fn = add_multimodal_fn
|
210 |
+
self.render_additional_inputs_fn = render_additional_inputs_fn
|
211 |
+
self.multimodal_inputs = []
|
212 |
+
self.is_async = inspect.iscoroutinefunction(
|
213 |
+
self.fn
|
214 |
+
) or inspect.isasyncgenfunction(self.fn)
|
215 |
+
self.is_generator = inspect.isgeneratorfunction(
|
216 |
+
self.fn
|
217 |
+
) or inspect.isasyncgenfunction(self.fn)
|
218 |
+
self.examples = examples
|
219 |
+
if self.space_id and cache_examples is None:
|
220 |
+
self.cache_examples = True
|
221 |
+
else:
|
222 |
+
self.cache_examples = cache_examples or False
|
223 |
+
self.buttons: list[Button | None] = []
|
224 |
+
|
225 |
+
if additional_inputs:
|
226 |
+
if not isinstance(additional_inputs, list):
|
227 |
+
additional_inputs = [additional_inputs]
|
228 |
+
self.additional_inputs = [
|
229 |
+
get_component_instance(i)
|
230 |
+
for i in additional_inputs # type: ignore
|
231 |
+
]
|
232 |
+
else:
|
233 |
+
self.additional_inputs = []
|
234 |
+
if additional_inputs_accordion_name is not None:
|
235 |
+
print(
|
236 |
+
"The `additional_inputs_accordion_name` parameter is deprecated and will be removed in a future version of Gradio. Use the `additional_inputs_accordion` parameter instead."
|
237 |
+
)
|
238 |
+
self.additional_inputs_accordion_params = {
|
239 |
+
"label": additional_inputs_accordion_name
|
240 |
+
}
|
241 |
+
if additional_inputs_accordion is None:
|
242 |
+
self.additional_inputs_accordion_params = {
|
243 |
+
"label": "Additional Inputs",
|
244 |
+
"open": False,
|
245 |
+
}
|
246 |
+
elif isinstance(additional_inputs_accordion, str):
|
247 |
+
self.additional_inputs_accordion_params = {
|
248 |
+
"label": additional_inputs_accordion
|
249 |
+
}
|
250 |
+
elif isinstance(additional_inputs_accordion, Accordion):
|
251 |
+
self.additional_inputs_accordion_params = (
|
252 |
+
additional_inputs_accordion.recover_kwargs(
|
253 |
+
additional_inputs_accordion.get_config()
|
254 |
+
)
|
255 |
+
)
|
256 |
+
else:
|
257 |
+
raise ValueError(
|
258 |
+
f"The `additional_inputs_accordion` parameter must be a string or gr.Accordion, not {type(additional_inputs_accordion)}"
|
259 |
+
)
|
260 |
+
|
261 |
+
with self:
|
262 |
+
if title:
|
263 |
+
Markdown(
|
264 |
+
f"<h1 style='text-align: center; margin-bottom: 1rem'>{self.title}</h1>"
|
265 |
+
)
|
266 |
+
if description:
|
267 |
+
Markdown(description)
|
268 |
+
|
269 |
+
if chatbot:
|
270 |
+
self.chatbot = chatbot.render()
|
271 |
+
else:
|
272 |
+
self.chatbot = Chatbot(
|
273 |
+
label="Chatbot", scale=1, height=200 if fill_height else None
|
274 |
+
)
|
275 |
+
|
276 |
+
with Row():
|
277 |
+
for btn in [retry_btn, undo_btn, clear_btn]:
|
278 |
+
if btn is not None:
|
279 |
+
if isinstance(btn, Button):
|
280 |
+
btn.render()
|
281 |
+
elif isinstance(btn, str):
|
282 |
+
btn = Button(btn, variant="secondary", size="sm")
|
283 |
+
else:
|
284 |
+
raise ValueError(
|
285 |
+
f"All the _btn parameters must be a gr.Button, string, or None, not {type(btn)}"
|
286 |
+
)
|
287 |
+
self.buttons.append(btn) # type: ignore
|
288 |
+
|
289 |
+
with Group():
|
290 |
+
with Row():
|
291 |
+
if textbox:
|
292 |
+
textbox.container = False
|
293 |
+
textbox.show_label = False
|
294 |
+
textbox_ = textbox.render()
|
295 |
+
assert isinstance(textbox_, Textbox)
|
296 |
+
self.textbox = textbox_
|
297 |
+
else:
|
298 |
+
self.textbox = Textbox(
|
299 |
+
container=False,
|
300 |
+
show_label=False,
|
301 |
+
label="Message",
|
302 |
+
placeholder="Type a message...",
|
303 |
+
scale=7,
|
304 |
+
autofocus=autofocus,
|
305 |
+
)
|
306 |
+
if submit_btn is not None:
|
307 |
+
if isinstance(submit_btn, Button):
|
308 |
+
submit_btn.render()
|
309 |
+
elif isinstance(submit_btn, str):
|
310 |
+
submit_btn = Button(
|
311 |
+
submit_btn,
|
312 |
+
variant="primary",
|
313 |
+
scale=2,
|
314 |
+
min_width=150,
|
315 |
+
)
|
316 |
+
else:
|
317 |
+
raise ValueError(
|
318 |
+
f"The submit_btn parameter must be a gr.Button, string, or None, not {type(submit_btn)}"
|
319 |
+
)
|
320 |
+
if stop_btn is not None:
|
321 |
+
if isinstance(stop_btn, Button):
|
322 |
+
stop_btn.visible = False
|
323 |
+
stop_btn.render()
|
324 |
+
elif isinstance(stop_btn, str):
|
325 |
+
stop_btn = Button(
|
326 |
+
stop_btn,
|
327 |
+
variant="stop",
|
328 |
+
visible=False,
|
329 |
+
scale=2,
|
330 |
+
min_width=150,
|
331 |
+
)
|
332 |
+
else:
|
333 |
+
raise ValueError(
|
334 |
+
f"The stop_btn parameter must be a gr.Button, string, or None, not {type(stop_btn)}"
|
335 |
+
)
|
336 |
+
self.num_tokens = Textbox(
|
337 |
+
container=False,
|
338 |
+
show_label=False,
|
339 |
+
label="num_tokens",
|
340 |
+
placeholder="0 tokens",
|
341 |
+
scale=1,
|
342 |
+
interactive=False,
|
343 |
+
# autofocus=autofocus,
|
344 |
+
min_width=10
|
345 |
+
)
|
346 |
+
self.buttons.extend([submit_btn, stop_btn]) # type: ignore
|
347 |
+
|
348 |
+
self.fake_api_btn = Button("Fake API", visible=False)
|
349 |
+
self.fake_response_textbox = Textbox(label="Response", visible=False)
|
350 |
+
(
|
351 |
+
self.retry_btn,
|
352 |
+
self.undo_btn,
|
353 |
+
self.clear_btn,
|
354 |
+
self.submit_btn,
|
355 |
+
self.stop_btn,
|
356 |
+
) = self.buttons
|
357 |
+
|
358 |
+
|
359 |
+
any_unrendered_inputs = any(
|
360 |
+
not inp.is_rendered for inp in self.additional_inputs
|
361 |
+
)
|
362 |
+
if self.add_multimodal_fn is not None:
|
363 |
+
with Row():
|
364 |
+
self.multimodal_inputs = self.add_multimodal_fn()
|
365 |
+
if self.additional_inputs and any_unrendered_inputs:
|
366 |
+
with Accordion(**self.additional_inputs_accordion_params): # type: ignore
|
367 |
+
if self.render_additional_inputs_fn is not None:
|
368 |
+
self.render_additional_inputs_fn()
|
369 |
+
else:
|
370 |
+
for input_component in self.additional_inputs:
|
371 |
+
if not input_component.is_rendered:
|
372 |
+
input_component.render()
|
373 |
+
else:
|
374 |
+
if self.additional_inputs and any_unrendered_inputs:
|
375 |
+
with Accordion(**self.additional_inputs_accordion_params): # type: ignore
|
376 |
+
if self.render_additional_inputs_fn is not None:
|
377 |
+
self.render_additional_inputs_fn()
|
378 |
+
else:
|
379 |
+
for input_component in self.additional_inputs:
|
380 |
+
if not input_component.is_rendered:
|
381 |
+
input_component.render()
|
382 |
+
|
383 |
+
if examples:
|
384 |
+
if self.is_generator:
|
385 |
+
examples_fn = self._examples_stream_fn
|
386 |
+
else:
|
387 |
+
# examples_fn = self._examples_fn
|
388 |
+
raise NotImplementedError(f'Not streaming not impl')
|
389 |
+
|
390 |
+
self.examples_handler = Examples(
|
391 |
+
examples=examples,
|
392 |
+
inputs=[self.textbox] + self.multimodal_inputs + self.additional_inputs,
|
393 |
+
outputs=self.chatbot,
|
394 |
+
fn=examples_fn,
|
395 |
+
)
|
396 |
+
|
397 |
+
# The example caching must happen after the input components have rendered
|
398 |
+
if cache_examples:
|
399 |
+
client_utils.synchronize_async(self.examples_handler.cache)
|
400 |
+
|
401 |
+
self.saved_input = State()
|
402 |
+
self.chatbot_state = (
|
403 |
+
State(self.chatbot.value) if self.chatbot.value else State([])
|
404 |
+
)
|
405 |
+
|
406 |
+
self._setup_events()
|
407 |
+
self._setup_api()
|
408 |
+
|
409 |
+
def _clear_and_save_textbox(self, message: str, *multimodal_inputs) -> tuple[str, str]:
|
410 |
+
saved_input = [message] + list(multimodal_inputs)
|
411 |
+
outputs = [''] + [None] * len(multimodal_inputs)
|
412 |
+
return outputs + [saved_input]
|
413 |
+
|
414 |
+
def _add_inputs_to_history(self, history: List[List[Union[str, None]]], *args):
|
415 |
+
message = args[0]
|
416 |
+
multimodal_inputs = args[1:1 + len(self.multimodal_inputs)] if len(args) > 1 else None
|
417 |
+
if multimodal_inputs is not None:
|
418 |
+
is_file_exists = [(x is not None and os.path.exists(x)) for x in multimodal_inputs]
|
419 |
+
if any(is_file_exists):
|
420 |
+
file_exists = [f for f, ise in zip(multimodal_inputs, is_file_exists) if ise]
|
421 |
+
if len(file_exists) > 1:
|
422 |
+
raise gr.Error(f"Cannot have more than 1 multimodal input at a time.")
|
423 |
+
fname = file_exists[0]
|
424 |
+
history.append([(fname,), None])
|
425 |
+
if message is not None and message.strip() != "":
|
426 |
+
history.append([message, None])
|
427 |
+
return history
|
428 |
+
|
429 |
+
|
430 |
+
def _display_input(
|
431 |
+
self, saved_input: List[str], history: List[List[Union[str, None]]]
|
432 |
+
) -> Tuple[List[List[Union[str, None]]], List[List[list[Union[str, None]]]]]:
|
433 |
+
# message = saved_input[0]
|
434 |
+
# multimodal_inputs = saved_input[1:] if len(saved_input) > 1 else None
|
435 |
+
# # ! If things wrong, return original history and give warning
|
436 |
+
# if multimodal_inputs is not None:
|
437 |
+
# is_file_exists = [(x is not None and os.path.exists(x)) for x in multimodal_inputs]
|
438 |
+
# if any(is_file_exists):
|
439 |
+
# file_exists = [f for f, ise in zip(multimodal_inputs, is_file_exists) if ise]
|
440 |
+
# if len(file_exists) > 1:
|
441 |
+
# raise gr.Error(f"Cannot have more than 1 multimodal input at a time.")
|
442 |
+
# fname = file_exists[0]
|
443 |
+
# history.append([(fname,), None])
|
444 |
+
# if message is not None and message.strip() != "":
|
445 |
+
# history.append([message, None])
|
446 |
+
history = self._add_inputs_to_history(history, *saved_input)
|
447 |
+
return history, history
|
448 |
+
|
449 |
+
def _delete_prev_fn(
|
450 |
+
self, history: list[list[str | None]]
|
451 |
+
) -> tuple[list[list[str | None]], str, list[list[str | None]]]:
|
452 |
+
try:
|
453 |
+
message, _ = history.pop()
|
454 |
+
except IndexError:
|
455 |
+
message = ""
|
456 |
+
saved_input = [message or ""] + [None] * len(self.multimodal_inputs)
|
457 |
+
return history, saved_input, history
|
458 |
+
|
459 |
+
def _setup_events(self) -> None:
|
460 |
+
from gradio.components import State
|
461 |
+
has_on = False
|
462 |
+
try:
|
463 |
+
from gradio.events import Dependency, EventListenerMethod, on
|
464 |
+
has_on = True
|
465 |
+
except ImportError as ie:
|
466 |
+
has_on = False
|
467 |
+
submit_fn = self._stream_fn if self.is_generator else self._submit_fn
|
468 |
+
if not self.is_generator:
|
469 |
+
raise NotImplementedError(f'should use generator')
|
470 |
+
|
471 |
+
if has_on:
|
472 |
+
# new version
|
473 |
+
submit_triggers = (
|
474 |
+
[self.textbox.submit, self.submit_btn.click]
|
475 |
+
if self.submit_btn
|
476 |
+
else [self.textbox.submit]
|
477 |
+
)
|
478 |
+
submit_event = (
|
479 |
+
on(
|
480 |
+
submit_triggers,
|
481 |
+
self._clear_and_save_textbox,
|
482 |
+
[self.textbox] + self.multimodal_inputs,
|
483 |
+
[self.textbox] + self.multimodal_inputs + [self.saved_input],
|
484 |
+
api_name=False,
|
485 |
+
queue=False,
|
486 |
+
)
|
487 |
+
.then(
|
488 |
+
self._display_input,
|
489 |
+
[self.saved_input, self.chatbot_state],
|
490 |
+
[self.chatbot, self.chatbot_state],
|
491 |
+
api_name=False,
|
492 |
+
queue=False,
|
493 |
+
)
|
494 |
+
.success(
|
495 |
+
submit_fn,
|
496 |
+
[self.chatbot_state] + self.additional_inputs,
|
497 |
+
[self.chatbot, self.chatbot_state, self.num_tokens],
|
498 |
+
api_name=False,
|
499 |
+
)
|
500 |
+
)
|
501 |
+
self._setup_stop_events(submit_triggers, submit_event)
|
502 |
+
else:
|
503 |
+
raise ValueError(f'Better install new gradio version than 3.44.0')
|
504 |
+
|
505 |
+
if self.retry_btn:
|
506 |
+
retry_event = (
|
507 |
+
self.retry_btn.click(
|
508 |
+
self._delete_prev_fn,
|
509 |
+
[self.chatbot_state],
|
510 |
+
[self.chatbot, self.saved_input, self.chatbot_state],
|
511 |
+
api_name=False,
|
512 |
+
queue=False,
|
513 |
+
)
|
514 |
+
.then(
|
515 |
+
self._display_input,
|
516 |
+
[self.saved_input, self.chatbot_state],
|
517 |
+
[self.chatbot, self.chatbot_state],
|
518 |
+
api_name=False,
|
519 |
+
queue=False,
|
520 |
+
)
|
521 |
+
.success(
|
522 |
+
submit_fn,
|
523 |
+
[self.chatbot_state] + self.additional_inputs,
|
524 |
+
[self.chatbot, self.chatbot_state, self.num_tokens],
|
525 |
+
api_name=False,
|
526 |
+
)
|
527 |
+
)
|
528 |
+
self._setup_stop_events([self.retry_btn.click], retry_event)
|
529 |
+
|
530 |
+
if self.undo_btn:
|
531 |
+
self.undo_btn.click(
|
532 |
+
# self._delete_prev_fn,
|
533 |
+
# [self.chatbot_state],
|
534 |
+
# [self.chatbot, self.saved_input, self.chatbot_state],
|
535 |
+
undo_history_until_last_assistant_turn,
|
536 |
+
[self.chatbot_state],
|
537 |
+
[self.chatbot, self.chatbot_state],
|
538 |
+
api_name=False,
|
539 |
+
queue=False,
|
540 |
+
)
|
541 |
+
# .then(
|
542 |
+
# lambda x: x,
|
543 |
+
# [self.saved_input],
|
544 |
+
# [self.textbox],
|
545 |
+
# api_name=False,
|
546 |
+
# queue=False,
|
547 |
+
# )
|
548 |
+
|
549 |
+
async def _stream_fn(
|
550 |
+
self,
|
551 |
+
# message: str,
|
552 |
+
history_with_input,
|
553 |
+
request: Request,
|
554 |
+
*args,
|
555 |
+
) -> AsyncGenerator:
|
556 |
+
history = history_with_input[:-1]
|
557 |
+
message = history_with_input[-1][0]
|
558 |
+
inputs, _, _ = special_args(
|
559 |
+
self.fn, inputs=[history_with_input, *args], request=request
|
560 |
+
)
|
561 |
+
|
562 |
+
if self.is_async:
|
563 |
+
generator = self.fn(*inputs)
|
564 |
+
else:
|
565 |
+
generator = await anyio.to_thread.run_sync(
|
566 |
+
self.fn, *inputs, limiter=self.limiter
|
567 |
+
)
|
568 |
+
generator = SyncToAsyncIterator(generator, self.limiter)
|
569 |
+
|
570 |
+
# ! In case of error, yield the previous history & undo any generation before raising error
|
571 |
+
try:
|
572 |
+
first_response_pack = await async_iteration(generator)
|
573 |
+
if isinstance(first_response_pack, (tuple, list)):
|
574 |
+
first_response, num_tokens = first_response_pack
|
575 |
+
else:
|
576 |
+
first_response, num_tokens = first_response_pack, -1
|
577 |
+
update = history + [[message, first_response]]
|
578 |
+
yield update, update, f"{num_tokens} toks"
|
579 |
+
except StopIteration:
|
580 |
+
update = history + [[message, None]]
|
581 |
+
yield update, update, "NaN toks"
|
582 |
+
except Exception as e:
|
583 |
+
yield history, history, "NaN toks"
|
584 |
+
raise e
|
585 |
+
|
586 |
+
try:
|
587 |
+
async for response_pack in generator:
|
588 |
+
if isinstance(response_pack, (tuple, list)):
|
589 |
+
response, num_tokens = response_pack
|
590 |
+
else:
|
591 |
+
response, num_tokens = response_pack, "NaN toks"
|
592 |
+
update = history + [[message, response]]
|
593 |
+
yield update, update, f"{num_tokens} toks"
|
594 |
+
except Exception as e:
|
595 |
+
yield history, history, "NaN toks"
|
596 |
+
raise e
|
597 |
+
|
598 |
+
async def _examples_stream_fn(
|
599 |
+
self,
|
600 |
+
# message: str,
|
601 |
+
*args,
|
602 |
+
) -> AsyncGenerator:
|
603 |
+
history = []
|
604 |
+
input_len = 1 + len(self.multimodal_inputs)
|
605 |
+
saved_input = args[:input_len]
|
606 |
+
message = saved_input[0]
|
607 |
+
additional_inputs = [] if len(args) <= input_len else args[input_len:]
|
608 |
+
history = self._add_inputs_to_history(history, *saved_input)
|
609 |
+
inputs, _, _ = special_args(self.fn, inputs=[history, *additional_inputs], request=None)
|
610 |
+
|
611 |
+
if self.is_async:
|
612 |
+
generator = self.fn(*inputs)
|
613 |
+
else:
|
614 |
+
generator = await anyio.to_thread.run_sync(
|
615 |
+
self.fn, *inputs, limiter=self.limiter
|
616 |
+
)
|
617 |
+
generator = SyncToAsyncIterator(generator, self.limiter)
|
618 |
+
# async for response in generator:
|
619 |
+
# yield [[message, response]]
|
620 |
+
|
621 |
+
try:
|
622 |
+
async for response_pack in generator:
|
623 |
+
if isinstance(response_pack, (tuple, list)):
|
624 |
+
response, num_tokens = response_pack
|
625 |
+
else:
|
626 |
+
response, num_tokens = response_pack, "NaN toks"
|
627 |
+
update = history + [[message, response]]
|
628 |
+
yield update, update, f"{num_tokens} toks"
|
629 |
+
except Exception as e:
|
630 |
+
yield history, history, "NaN toks"
|
631 |
+
raise e
|
632 |
+
|
633 |
+
async def _examples_fn(self, message: str, *args) -> list[list[str | None]]:
|
634 |
+
raise NotImplementedError
|
635 |
+
inputs, _, _ = special_args(self.fn, inputs=[message, [], *args], request=None)
|
636 |
+
|
637 |
+
if self.is_async:
|
638 |
+
response = await self.fn(*inputs)
|
639 |
+
else:
|
640 |
+
response = await anyio.to_thread.run_sync(
|
641 |
+
self.fn, *inputs, limiter=self.limiter
|
642 |
+
)
|
643 |
+
return [[message, response]]
|
644 |
+
|
645 |
+
|
646 |
+
|
647 |
+
def gradio_history_to_openai_conversations(message=None, history=None, system_prompt=None):
|
648 |
+
conversations = []
|
649 |
+
system_prompt = system_prompt or SYSTEM_PROMPT
|
650 |
+
if history is not None and len(history) > 0:
|
651 |
+
for i, (prompt, res) in enumerate(history):
|
652 |
+
if prompt is not None:
|
653 |
+
conversations.append({"role": "user", "content": prompt.strip()})
|
654 |
+
if res is not None:
|
655 |
+
conversations.append({"role": "assistant", "content": res.strip()})
|
656 |
+
if message is not None:
|
657 |
+
if len(message.strip()) == 0:
|
658 |
+
raise gr.Error("The message cannot be empty!")
|
659 |
+
conversations.append({"role": "user", "content": message.strip()})
|
660 |
+
if conversations[0]['role'] != 'system':
|
661 |
+
conversations = [{"role": "system", "content": system_prompt}] + conversations
|
662 |
+
return conversations
|
663 |
+
|
664 |
+
|
665 |
+
def gradio_history_to_conversation_prompt(message=None, history=None, system_prompt=None):
|
666 |
+
global MODEL_ENGINE
|
667 |
+
full_prompt = MODEL_ENGINE.apply_chat_template(
|
668 |
+
gradio_history_to_openai_conversations(
|
669 |
+
message, history=history, system_prompt=system_prompt),
|
670 |
+
add_generation_prompt=True
|
671 |
+
)
|
672 |
+
return full_prompt
|
673 |
+
|
674 |
+
|
675 |
+
def gradio_history_to_vision_conversation_prompt_paths(
|
676 |
+
history, system_prompt=None, image_token=None
|
677 |
+
):
|
678 |
+
"""
|
679 |
+
Aggregate gradio history into openai conversations
|
680 |
+
history = [
|
681 |
+
["Hello", "Response"],
|
682 |
+
[(file,), None],
|
683 |
+
]
|
684 |
+
--->
|
685 |
+
[
|
686 |
+
{"role": "user", "content": ...}
|
687 |
+
]
|
688 |
+
"""
|
689 |
+
global MODEL_ENGINE
|
690 |
+
image_token = image_token or IMAGE_TOKEN
|
691 |
+
conversations = []
|
692 |
+
image_paths = []
|
693 |
+
for i, his in enumerate(history):
|
694 |
+
prompt, response = his
|
695 |
+
last_turn = conversations[-1] if len(conversations) > 0 else None
|
696 |
+
if prompt is not None:
|
697 |
+
if isinstance(prompt, tuple):
|
698 |
+
image_path = prompt[0]
|
699 |
+
if last_turn is not None and last_turn['role'] == 'user':
|
700 |
+
last_turn['content'] += f" {image_token}"
|
701 |
+
else:
|
702 |
+
# last_turn None or last_turn['role'] == 'assistant'
|
703 |
+
conversations.append({
|
704 |
+
"role": "user",
|
705 |
+
"content": f"{image_token}"
|
706 |
+
})
|
707 |
+
image_paths.append(image_path)
|
708 |
+
else:
|
709 |
+
assert prompt is not None and isinstance(prompt, str)
|
710 |
+
if last_turn is not None and last_turn['role'] == 'user':
|
711 |
+
last_turn['content'] += f"\n{prompt}"
|
712 |
+
else:
|
713 |
+
conversations.append({
|
714 |
+
"role": "user",
|
715 |
+
"content": prompt,
|
716 |
+
})
|
717 |
+
if response is not None:
|
718 |
+
assert isinstance(response, str)
|
719 |
+
conversations.append({
|
720 |
+
"role": "assistant",
|
721 |
+
"content": response,
|
722 |
+
})
|
723 |
+
|
724 |
+
if conversations[0]['role'] != 'system':
|
725 |
+
system_prompt = system_prompt or SYSTEM_PROMPT
|
726 |
+
conversations = [{"role": "system", "content": system_prompt}] + conversations
|
727 |
+
|
728 |
+
# print(f'convo: {json.dumps(conversations, indent=4, ensure_ascii=False)}\n{image_paths=}')
|
729 |
+
full_prompt = MODEL_ENGINE.apply_chat_template(
|
730 |
+
conversations,
|
731 |
+
add_generation_prompt=True
|
732 |
+
)
|
733 |
+
return full_prompt, image_paths, conversations
|
734 |
+
|
735 |
+
|
736 |
+
def is_doc(file_path):
|
737 |
+
is_doc_allowed = file_path.endswith((".pdf", ".docx", ".txt"))
|
738 |
+
return is_doc_allowed
|
739 |
+
|
740 |
+
|
741 |
+
def read_doc(file_path):
|
742 |
+
from langchain_community.document_loaders import PyPDFLoader, Docx2txtLoader, TextLoader
|
743 |
+
if file_path.endswith('.pdf'):
|
744 |
+
loader = PyPDFLoader(file_path)
|
745 |
+
elif file_path.endswith('.docx'):
|
746 |
+
loader = Docx2txtLoader(file_path)
|
747 |
+
elif file_path.endswith('.txt'):
|
748 |
+
loader = TextLoader(file_path)
|
749 |
+
texts = loader.load()
|
750 |
+
text = "\n\n".join([t.page_content for t in texts])
|
751 |
+
return text
|
752 |
+
|
753 |
+
|
754 |
+
def doc_file_to_instruct_content(file_path, doc_instruction=None):
|
755 |
+
doc_instruction = doc_instruction or DOC_INSTRUCTION
|
756 |
+
content = doc_instruction.strip() + "\n" + DOC_TEMPLATE.format(content=read_doc(file_path))
|
757 |
+
return content
|
758 |
+
|
759 |
+
|
760 |
+
def gradio_history_to_doc_conversation_prompt(
|
761 |
+
history, system_prompt=None, doc_instruction=None,
|
762 |
+
):
|
763 |
+
"""
|
764 |
+
Aggregate gradio history into openai conversations
|
765 |
+
history = [
|
766 |
+
["Hello", "Response"],
|
767 |
+
[(file,), None],
|
768 |
+
]
|
769 |
+
--->
|
770 |
+
[
|
771 |
+
{"role": "user", "content": ...}
|
772 |
+
]
|
773 |
+
"""
|
774 |
+
global MODEL_ENGINE
|
775 |
+
# image_token = image_token or IMAGE_TOKEN
|
776 |
+
doc_instruction = doc_instruction or DOC_INSTRUCTION
|
777 |
+
conversations = []
|
778 |
+
image_paths = []
|
779 |
+
for i, his in enumerate(history):
|
780 |
+
prompt, response = his
|
781 |
+
last_turn = conversations[-1] if len(conversations) > 0 else None
|
782 |
+
if prompt is not None:
|
783 |
+
if isinstance(prompt, tuple):
|
784 |
+
file_path = prompt[0]
|
785 |
+
if not is_doc(file_path):
|
786 |
+
raise gr.Error(f'file not doc {file_path}')
|
787 |
+
content = doc_file_to_instruct_content(file_path, doc_instruction)
|
788 |
+
if last_turn is not None and last_turn['role'] == 'user':
|
789 |
+
last_turn['content'] += f"{content}"
|
790 |
+
else:
|
791 |
+
# last_turn None or last_turn['role'] == 'assistant'
|
792 |
+
conversations.append({
|
793 |
+
"role": "user",
|
794 |
+
"content": f"{content}"
|
795 |
+
})
|
796 |
+
else:
|
797 |
+
assert prompt is not None and isinstance(prompt, str)
|
798 |
+
if last_turn is not None and last_turn['role'] == 'user':
|
799 |
+
last_turn['content'] += f"\n{prompt}"
|
800 |
+
else:
|
801 |
+
conversations.append({
|
802 |
+
"role": "user",
|
803 |
+
"content": prompt,
|
804 |
+
})
|
805 |
+
if response is not None:
|
806 |
+
assert isinstance(response, str)
|
807 |
+
conversations.append({
|
808 |
+
"role": "assistant",
|
809 |
+
"content": response,
|
810 |
+
})
|
811 |
+
|
812 |
+
if conversations[0]['role'] != 'system':
|
813 |
+
system_prompt = system_prompt or SYSTEM_PROMPT
|
814 |
+
conversations = [{"role": "system", "content": system_prompt}] + conversations
|
815 |
+
|
816 |
+
full_prompt = MODEL_ENGINE.apply_chat_template(
|
817 |
+
conversations,
|
818 |
+
add_generation_prompt=True
|
819 |
+
)
|
820 |
+
return full_prompt, conversations
|
821 |
+
|
822 |
+
|
823 |
+
def gradio_history_to_vision_doc_conversation_prompt_paths(
|
824 |
+
history, system_prompt=None, image_token=None, doc_instruction=None,
|
825 |
+
):
|
826 |
+
"""
|
827 |
+
Aggregate gradio history into openai conversations
|
828 |
+
history = [
|
829 |
+
["Hello", "Response"],
|
830 |
+
[(file,), None],
|
831 |
+
]
|
832 |
+
--->
|
833 |
+
[
|
834 |
+
{"role": "user", "content": ...}
|
835 |
+
]
|
836 |
+
"""
|
837 |
+
global MODEL_ENGINE
|
838 |
+
image_token = image_token or IMAGE_TOKEN
|
839 |
+
doc_instruction = doc_instruction or DOC_INSTRUCTION
|
840 |
+
conversations = []
|
841 |
+
image_paths = []
|
842 |
+
for i, his in enumerate(history):
|
843 |
+
prompt, response = his
|
844 |
+
last_turn = conversations[-1] if len(conversations) > 0 else None
|
845 |
+
if prompt is not None:
|
846 |
+
if isinstance(prompt, tuple):
|
847 |
+
file_path = prompt[0]
|
848 |
+
if is_doc(file_path):
|
849 |
+
content = doc_file_to_instruct_content(file_path, doc_instruction)
|
850 |
+
if last_turn is not None and last_turn['role'] == 'user':
|
851 |
+
last_turn['content'] += f"{content}"
|
852 |
+
else:
|
853 |
+
# last_turn None or last_turn['role'] == 'assistant'
|
854 |
+
conversations.append({
|
855 |
+
"role": "user",
|
856 |
+
"content": f"{content}"
|
857 |
+
})
|
858 |
+
else:
|
859 |
+
if last_turn is not None and last_turn['role'] == 'user':
|
860 |
+
last_turn['content'] += f" {image_token}"
|
861 |
+
else:
|
862 |
+
# last_turn None or last_turn['role'] == 'assistant'
|
863 |
+
conversations.append({
|
864 |
+
"role": "user",
|
865 |
+
"content": f"{image_token}"
|
866 |
+
})
|
867 |
+
image_paths.append(file_path)
|
868 |
+
else:
|
869 |
+
assert prompt is not None and isinstance(prompt, str)
|
870 |
+
if last_turn is not None and last_turn['role'] == 'user':
|
871 |
+
last_turn['content'] += f"\n{prompt}"
|
872 |
+
else:
|
873 |
+
conversations.append({
|
874 |
+
"role": "user",
|
875 |
+
"content": prompt,
|
876 |
+
})
|
877 |
+
if response is not None:
|
878 |
+
assert isinstance(response, str)
|
879 |
+
conversations.append({
|
880 |
+
"role": "assistant",
|
881 |
+
"content": response,
|
882 |
+
})
|
883 |
+
|
884 |
+
if conversations[0]['role'] != 'system':
|
885 |
+
system_prompt = system_prompt or SYSTEM_PROMPT
|
886 |
+
conversations = [{"role": "system", "content": system_prompt}] + conversations
|
887 |
+
|
888 |
+
full_prompt = MODEL_ENGINE.apply_chat_template(
|
889 |
+
conversations,
|
890 |
+
add_generation_prompt=True
|
891 |
+
)
|
892 |
+
return full_prompt, image_paths, conversations
|
893 |
+
|
894 |
+
|
895 |
+
def vision_chat_response_stream_multiturn_engine(
|
896 |
+
history: List[Tuple[str, str]],
|
897 |
+
temperature: float,
|
898 |
+
max_tokens: int,
|
899 |
+
system_prompt: Optional[str] = SYSTEM_PROMPT,
|
900 |
+
image_token: Optional[str] = IMAGE_TOKEN,
|
901 |
+
):
|
902 |
+
global MODEL_ENGINE
|
903 |
+
temperature = float(temperature)
|
904 |
+
# ! remove frequency_penalty
|
905 |
+
# frequency_penalty = float(frequency_penalty)
|
906 |
+
max_tokens = int(max_tokens)
|
907 |
+
# ! skip safety
|
908 |
+
if DATETIME_FORMAT in system_prompt:
|
909 |
+
# ! This sometime works sometimes dont
|
910 |
+
system_prompt = system_prompt.format(cur_datetime=get_datetime_string())
|
911 |
+
# ! history now can have multimodal
|
912 |
+
|
913 |
+
full_prompt, image_paths, conversations = gradio_history_to_vision_conversation_prompt_paths(
|
914 |
+
history=history, system_prompt=system_prompt, image_token=image_token
|
915 |
+
)
|
916 |
+
|
917 |
+
if hasattr(MODEL_ENGINE, "get_multimodal_tokens"):
|
918 |
+
num_tokens = MODEL_ENGINE.get_multimodal_tokens(full_prompt, image_paths=image_paths)
|
919 |
+
else:
|
920 |
+
num_tokens = len(MODEL_ENGINE.tokenizer.encode(full_prompt))
|
921 |
+
if num_tokens >= MODEL_ENGINE.max_position_embeddings - 128:
|
922 |
+
raise gr.Error(f"Conversation or prompt is too long ({num_tokens} toks), please clear the chatbox or try shorter input.")
|
923 |
+
|
924 |
+
print(f'{image_paths=}')
|
925 |
+
print(full_prompt)
|
926 |
+
outputs = None
|
927 |
+
response = None
|
928 |
+
num_tokens = -1
|
929 |
+
for j, outputs in enumerate(MODEL_ENGINE.generate_yield_string(
|
930 |
+
prompt=full_prompt,
|
931 |
+
temperature=temperature,
|
932 |
+
max_tokens=max_tokens,
|
933 |
+
image_paths=image_paths,
|
934 |
+
)):
|
935 |
+
if isinstance(outputs, tuple):
|
936 |
+
response, num_tokens = outputs
|
937 |
+
else:
|
938 |
+
response, num_tokens = outputs, -1
|
939 |
+
yield response, num_tokens
|
940 |
+
|
941 |
+
print(format_conversation(history + [[None, response]]))
|
942 |
+
|
943 |
+
if response is not None:
|
944 |
+
yield response, num_tokens
|
945 |
+
|
946 |
+
|
947 |
+
def doc_chat_response_stream_multiturn_engine(
|
948 |
+
history: List[Tuple[str, str]],
|
949 |
+
temperature: float,
|
950 |
+
max_tokens: int,
|
951 |
+
system_prompt: Optional[str] = SYSTEM_PROMPT,
|
952 |
+
doc_instruction: Optional[str] = DOC_INSTRUCTION,
|
953 |
+
):
|
954 |
+
global MODEL_ENGINE
|
955 |
+
temperature = float(temperature)
|
956 |
+
# ! remove frequency_penalty
|
957 |
+
# frequency_penalty = float(frequency_penalty)
|
958 |
+
max_tokens = int(max_tokens)
|
959 |
+
# ! skip safety
|
960 |
+
if DATETIME_FORMAT in system_prompt:
|
961 |
+
# ! This sometime works sometimes dont
|
962 |
+
system_prompt = system_prompt.format(cur_datetime=get_datetime_string())
|
963 |
+
# ! history now can have multimodal
|
964 |
+
|
965 |
+
full_prompt, conversations = gradio_history_to_doc_conversation_prompt(
|
966 |
+
history=history, system_prompt=system_prompt, doc_instruction=doc_instruction
|
967 |
+
)
|
968 |
+
|
969 |
+
# ! length checked
|
970 |
+
num_tokens = len(MODEL_ENGINE.tokenizer.encode(full_prompt))
|
971 |
+
if num_tokens >= MODEL_ENGINE.max_position_embeddings - 128:
|
972 |
+
raise gr.Error(f"Conversation or prompt is too long ({num_tokens} toks), please clear the chatbox or try shorter input.")
|
973 |
+
|
974 |
+
print(full_prompt)
|
975 |
+
outputs = None
|
976 |
+
response = None
|
977 |
+
num_tokens = -1
|
978 |
+
for j, outputs in enumerate(MODEL_ENGINE.generate_yield_string(
|
979 |
+
prompt=full_prompt,
|
980 |
+
temperature=temperature,
|
981 |
+
max_tokens=max_tokens,
|
982 |
+
# image_paths=image_paths,
|
983 |
+
)):
|
984 |
+
if isinstance(outputs, tuple):
|
985 |
+
response, num_tokens = outputs
|
986 |
+
else:
|
987 |
+
response, num_tokens = outputs, -1
|
988 |
+
yield response, num_tokens
|
989 |
+
|
990 |
+
print(format_conversation(history + [[None, response]]))
|
991 |
+
|
992 |
+
if response is not None:
|
993 |
+
yield response, num_tokens
|
994 |
+
|
995 |
+
|
996 |
+
|
997 |
+
|
998 |
+
def vision_doc_chat_response_stream_multiturn_engine(
|
999 |
+
history: List[Tuple[str, str]],
|
1000 |
+
temperature: float,
|
1001 |
+
max_tokens: int,
|
1002 |
+
system_prompt: Optional[str] = SYSTEM_PROMPT,
|
1003 |
+
image_token: Optional[str] = IMAGE_TOKEN,
|
1004 |
+
doc_instruction: Optional[str] = DOC_INSTRUCTION,
|
1005 |
+
):
|
1006 |
+
global MODEL_ENGINE
|
1007 |
+
temperature = float(temperature)
|
1008 |
+
# ! remove frequency_penalty
|
1009 |
+
# frequency_penalty = float(frequency_penalty)
|
1010 |
+
max_tokens = int(max_tokens)
|
1011 |
+
# ! skip safety
|
1012 |
+
if DATETIME_FORMAT in system_prompt:
|
1013 |
+
# ! This sometime works sometimes dont
|
1014 |
+
system_prompt = system_prompt.format(cur_datetime=get_datetime_string())
|
1015 |
+
# ! history now can have multimodal
|
1016 |
+
|
1017 |
+
full_prompt, image_paths, conversations = gradio_history_to_vision_doc_conversation_prompt_paths(
|
1018 |
+
history=history, system_prompt=system_prompt, image_token=image_token, doc_instruction=doc_instruction
|
1019 |
+
)
|
1020 |
+
|
1021 |
+
# ! length check
|
1022 |
+
if hasattr(MODEL_ENGINE, "get_multimodal_tokens"):
|
1023 |
+
num_tokens = MODEL_ENGINE.get_multimodal_tokens(full_prompt, image_paths=image_paths)
|
1024 |
+
else:
|
1025 |
+
num_tokens = len(MODEL_ENGINE.tokenizer.encode(full_prompt))
|
1026 |
+
if num_tokens >= MODEL_ENGINE.max_position_embeddings - 128:
|
1027 |
+
raise gr.Error(f"Conversation or prompt is too long ({num_tokens} toks), please clear the chatbox or try shorter input.")
|
1028 |
+
|
1029 |
+
print(full_prompt)
|
1030 |
+
print(f'{image_paths=}')
|
1031 |
+
outputs = None
|
1032 |
+
response = None
|
1033 |
+
num_tokens = -1
|
1034 |
+
for j, outputs in enumerate(MODEL_ENGINE.generate_yield_string(
|
1035 |
+
prompt=full_prompt,
|
1036 |
+
temperature=temperature,
|
1037 |
+
max_tokens=max_tokens,
|
1038 |
+
image_paths=image_paths,
|
1039 |
+
)):
|
1040 |
+
if isinstance(outputs, tuple):
|
1041 |
+
response, num_tokens = outputs
|
1042 |
+
else:
|
1043 |
+
response, num_tokens = outputs, -1
|
1044 |
+
yield response, num_tokens
|
1045 |
+
|
1046 |
+
print(format_conversation(history + [[None, response]]))
|
1047 |
+
|
1048 |
+
if response is not None:
|
1049 |
+
yield response, num_tokens
|
1050 |
+
|
1051 |
+
|
1052 |
+
|
1053 |
+
@register_demo
|
1054 |
+
class VisionChatInterfaceDemo(ChatInterfaceDemo):
|
1055 |
+
"""
|
1056 |
+
Accept vision image
|
1057 |
+
"""
|
1058 |
+
|
1059 |
+
@property
|
1060 |
+
def tab_name(self):
|
1061 |
+
return "Vision Chat"
|
1062 |
+
|
1063 |
+
@property
|
1064 |
+
def examples(self):
|
1065 |
+
return [
|
1066 |
+
["What's strange about this image?", "assets/dog_monalisa.jpeg",],
|
1067 |
+
["Explain why the sky is blue.", None,],
|
1068 |
+
]
|
1069 |
+
|
1070 |
+
def create_demo(
|
1071 |
+
self,
|
1072 |
+
title: str | None = None,
|
1073 |
+
description: str | None = None,
|
1074 |
+
**kwargs
|
1075 |
+
) -> gr.Blocks:
|
1076 |
+
system_prompt = kwargs.get("system_prompt", SYSTEM_PROMPT)
|
1077 |
+
max_tokens = kwargs.get("max_tokens", MAX_TOKENS)
|
1078 |
+
temperature = kwargs.get("temperature", TEMPERATURE)
|
1079 |
+
model_name = kwargs.get("model_name", MODEL_NAME)
|
1080 |
+
description = description or """Upload an image to ask question about it."""
|
1081 |
+
|
1082 |
+
def add_multimodal_fn() -> List[Component]:
|
1083 |
+
image_input = gr.Image(label="Input Image", type="filepath", )
|
1084 |
+
return [image_input]
|
1085 |
+
|
1086 |
+
additional_inputs = [
|
1087 |
+
gr.Number(value=temperature, label='Temperature', min_width=20),
|
1088 |
+
gr.Number(value=max_tokens, label='Max-tokens', min_width=20),
|
1089 |
+
gr.Textbox(value=system_prompt, label='System prompt', lines=1),
|
1090 |
+
gr.Textbox(value=IMAGE_TOKEN, label='Visual token', lines=1, interactive=IMAGE_TOKEN_INTERACTIVE, min_width=20),
|
1091 |
+
]
|
1092 |
+
def render_additional_inputs_fn():
|
1093 |
+
with Row():
|
1094 |
+
additional_inputs[0].render()
|
1095 |
+
additional_inputs[1].render()
|
1096 |
+
additional_inputs[3].render()
|
1097 |
+
additional_inputs[2].render()
|
1098 |
+
|
1099 |
+
demo_chat = MultiModalChatInterface(
|
1100 |
+
vision_chat_response_stream_multiturn_engine,
|
1101 |
+
chatbot=gr.Chatbot(
|
1102 |
+
label=model_name,
|
1103 |
+
bubble_full_width=False,
|
1104 |
+
latex_delimiters=[
|
1105 |
+
{ "left": "$", "right": "$", "display": False},
|
1106 |
+
{ "left": "$$", "right": "$$", "display": True},
|
1107 |
+
],
|
1108 |
+
show_copy_button=True,
|
1109 |
+
layout="panel" if USE_PANEL else "bubble",
|
1110 |
+
height=CHATBOT_HEIGHT,
|
1111 |
+
),
|
1112 |
+
# textbox=gr.Textbox(placeholder='Type message', lines=4, max_lines=128, min_width=200),
|
1113 |
+
textbox=gr.Textbox(placeholder='Type message', lines=1, max_lines=128, min_width=200, scale=8),
|
1114 |
+
submit_btn=gr.Button(value='Submit', variant="primary", scale=0),
|
1115 |
+
# ! consider preventing the stop button
|
1116 |
+
# stop_btn=None,
|
1117 |
+
add_multimodal_fn=add_multimodal_fn,
|
1118 |
+
title=title,
|
1119 |
+
description=description,
|
1120 |
+
additional_inputs=additional_inputs,
|
1121 |
+
render_additional_inputs_fn=render_additional_inputs_fn,
|
1122 |
+
additional_inputs_accordion=gr.Accordion("Additional Inputs", open=True),
|
1123 |
+
examples=self.examples,
|
1124 |
+
cache_examples=False,
|
1125 |
+
css=CSS,
|
1126 |
+
)
|
1127 |
+
return demo_chat
|
1128 |
+
|
1129 |
+
|
1130 |
+
def add_document_upload():
|
1131 |
+
file_input = gr.File(label='Upload pdf, docx, txt', file_count='single', file_types=['pdf', 'docx', 'txt'])
|
1132 |
+
# ! Some platform has problems with gr.File, so use uploadbutton instead
|
1133 |
+
# with Group():
|
1134 |
+
# file_input = gr.Textbox(value=None, label='Document path', lines=1, interactive=False)
|
1135 |
+
# upload_button = gr.UploadButton("Click to Upload document", file_types=['pdf', 'docx', 'txt'], file_count="single")
|
1136 |
+
# upload_button.upload(lambda x: x.name, upload_button, file_input)
|
1137 |
+
return file_input
|
1138 |
+
|
1139 |
+
|
1140 |
+
@register_demo
|
1141 |
+
class DocChatInterfaceDemo(ChatInterfaceDemo):
|
1142 |
+
"""
|
1143 |
+
Accept document (full length no RAG)
|
1144 |
+
"""
|
1145 |
+
@property
|
1146 |
+
def tab_name(self):
|
1147 |
+
return "Doc Chat"
|
1148 |
+
|
1149 |
+
@property
|
1150 |
+
def examples(self):
|
1151 |
+
return [
|
1152 |
+
["Summarize the document", "assets/attention_short.pdf",],
|
1153 |
+
["Explain why the sky is blue.", None,],
|
1154 |
+
]
|
1155 |
+
|
1156 |
+
def create_demo(
|
1157 |
+
self,
|
1158 |
+
title: str | None = None,
|
1159 |
+
description: str | None = None,
|
1160 |
+
**kwargs
|
1161 |
+
) -> gr.Blocks:
|
1162 |
+
system_prompt = kwargs.get("system_prompt", SYSTEM_PROMPT)
|
1163 |
+
max_tokens = kwargs.get("max_tokens", MAX_TOKENS)
|
1164 |
+
temperature = kwargs.get("temperature", TEMPERATURE)
|
1165 |
+
model_name = kwargs.get("model_name", MODEL_NAME)
|
1166 |
+
# frequence_penalty = FREQUENCE_PENALTY
|
1167 |
+
# presence_penalty = PRESENCE_PENALTY
|
1168 |
+
description = description or """Upload a short document to ask question about it."""
|
1169 |
+
|
1170 |
+
def add_multimodal_fn() -> List[Component]:
|
1171 |
+
file_input = add_document_upload()
|
1172 |
+
# image_input = gr.Image(label="Input Image", type="filepath", )
|
1173 |
+
return [file_input]
|
1174 |
+
|
1175 |
+
additional_inputs = [
|
1176 |
+
gr.Number(value=temperature, label='Temperature', min_width=20),
|
1177 |
+
gr.Number(value=max_tokens, label='Max-tokens', min_width=20),
|
1178 |
+
gr.Textbox(value=system_prompt, label='System prompt', lines=1),
|
1179 |
+
gr.Textbox(value=DOC_INSTRUCTION, label='Doc instruction', lines=1),
|
1180 |
+
]
|
1181 |
+
def render_additional_inputs_fn():
|
1182 |
+
with Row():
|
1183 |
+
additional_inputs[0].render()
|
1184 |
+
additional_inputs[1].render()
|
1185 |
+
additional_inputs[2].render()
|
1186 |
+
additional_inputs[3].render()
|
1187 |
+
|
1188 |
+
demo_chat = MultiModalChatInterface(
|
1189 |
+
doc_chat_response_stream_multiturn_engine,
|
1190 |
+
chatbot=gr.Chatbot(
|
1191 |
+
label=model_name,
|
1192 |
+
bubble_full_width=False,
|
1193 |
+
latex_delimiters=[
|
1194 |
+
{ "left": "$", "right": "$", "display": False},
|
1195 |
+
{ "left": "$$", "right": "$$", "display": True},
|
1196 |
+
],
|
1197 |
+
show_copy_button=True,
|
1198 |
+
layout="panel" if USE_PANEL else "bubble",
|
1199 |
+
height=CHATBOT_HEIGHT,
|
1200 |
+
),
|
1201 |
+
textbox=gr.Textbox(placeholder='Type message', lines=1, max_lines=128, min_width=200, scale=8),
|
1202 |
+
submit_btn=gr.Button(value='Submit', variant="primary", scale=0),
|
1203 |
+
# ! consider preventing the stop button
|
1204 |
+
add_multimodal_fn=add_multimodal_fn,
|
1205 |
+
title=title,
|
1206 |
+
description=description,
|
1207 |
+
additional_inputs=additional_inputs,
|
1208 |
+
render_additional_inputs_fn=render_additional_inputs_fn,
|
1209 |
+
additional_inputs_accordion=gr.Accordion("Additional Inputs", open=True),
|
1210 |
+
examples=self.examples,
|
1211 |
+
cache_examples=False,
|
1212 |
+
css=CSS,
|
1213 |
+
)
|
1214 |
+
return demo_chat
|
1215 |
+
|
1216 |
+
|
1217 |
+
@register_demo
|
1218 |
+
class VisionDocChatInterfaceDemo(ChatInterfaceDemo):
|
1219 |
+
"""
|
1220 |
+
Accept either vision image or document (full length no RAG)
|
1221 |
+
"""
|
1222 |
+
@property
|
1223 |
+
def tab_name(self):
|
1224 |
+
return "Vision Doc Chat"
|
1225 |
+
|
1226 |
+
@property
|
1227 |
+
def examples(self):
|
1228 |
+
return [
|
1229 |
+
["What's strange about this image?", None, "assets/dog_monalisa.jpeg",],
|
1230 |
+
["Summarize the document", "assets/attention_short.pdf", None,],
|
1231 |
+
["Explain why the sky is blue.", None, None],
|
1232 |
+
]
|
1233 |
+
|
1234 |
+
def create_demo(
|
1235 |
+
self,
|
1236 |
+
title: str | None = None,
|
1237 |
+
description: str | None = None,
|
1238 |
+
**kwargs
|
1239 |
+
) -> gr.Blocks:
|
1240 |
+
system_prompt = kwargs.get("system_prompt", SYSTEM_PROMPT)
|
1241 |
+
max_tokens = kwargs.get("max_tokens", MAX_TOKENS)
|
1242 |
+
temperature = kwargs.get("temperature", TEMPERATURE)
|
1243 |
+
model_name = kwargs.get("model_name", MODEL_NAME)
|
1244 |
+
# frequence_penalty = FREQUENCE_PENALTY
|
1245 |
+
# presence_penalty = PRESENCE_PENALTY
|
1246 |
+
description = description or """Upload either an image or short document to ask question about it."""
|
1247 |
+
|
1248 |
+
def add_multimodal_fn() -> List[Component]:
|
1249 |
+
file_input = add_document_upload()
|
1250 |
+
image_input = gr.Image(label="Input Image", type="filepath", )
|
1251 |
+
return [file_input, image_input]
|
1252 |
+
|
1253 |
+
additional_inputs = [
|
1254 |
+
gr.Number(value=temperature, label='Temperature', min_width=20),
|
1255 |
+
gr.Number(value=max_tokens, label='Max-tokens', min_width=20),
|
1256 |
+
gr.Textbox(value=system_prompt, label='System prompt', lines=1),
|
1257 |
+
gr.Textbox(value=IMAGE_TOKEN, label='Visual token', lines=1, interactive=IMAGE_TOKEN_INTERACTIVE, min_width=2),
|
1258 |
+
gr.Textbox(value=DOC_INSTRUCTION, label='Doc instruction', lines=1),
|
1259 |
+
]
|
1260 |
+
def render_additional_inputs_fn():
|
1261 |
+
with Row():
|
1262 |
+
additional_inputs[0].render()
|
1263 |
+
additional_inputs[1].render()
|
1264 |
+
additional_inputs[3].render()
|
1265 |
+
additional_inputs[2].render()
|
1266 |
+
additional_inputs[4].render()
|
1267 |
+
|
1268 |
+
demo_chat = MultiModalChatInterface(
|
1269 |
+
vision_doc_chat_response_stream_multiturn_engine,
|
1270 |
+
chatbot=gr.Chatbot(
|
1271 |
+
label=MODEL_NAME,
|
1272 |
+
bubble_full_width=False,
|
1273 |
+
latex_delimiters=[
|
1274 |
+
{ "left": "$", "right": "$", "display": False},
|
1275 |
+
{ "left": "$$", "right": "$$", "display": True},
|
1276 |
+
],
|
1277 |
+
show_copy_button=True,
|
1278 |
+
layout="panel" if USE_PANEL else "bubble",
|
1279 |
+
height=CHATBOT_HEIGHT,
|
1280 |
+
),
|
1281 |
+
textbox=gr.Textbox(placeholder='Type message', lines=1, max_lines=128, min_width=200, scale=8),
|
1282 |
+
submit_btn=gr.Button(value='Submit', variant="primary", scale=0),
|
1283 |
+
add_multimodal_fn=add_multimodal_fn,
|
1284 |
+
title=title,
|
1285 |
+
description=description,
|
1286 |
+
additional_inputs=additional_inputs,
|
1287 |
+
render_additional_inputs_fn=render_additional_inputs_fn,
|
1288 |
+
additional_inputs_accordion=gr.Accordion("Additional Inputs", open=True),
|
1289 |
+
examples=self.examples,
|
1290 |
+
cache_examples=False,
|
1291 |
+
css=CSS,
|
1292 |
+
)
|
1293 |
+
return demo_chat
|
multipurpose_chatbot/demos/rag_chat_interface.py
ADDED
@@ -0,0 +1,642 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
from gradio.themes import ThemeClass as Theme
|
3 |
+
import numpy as np
|
4 |
+
import argparse
|
5 |
+
import gradio as gr
|
6 |
+
from typing import Any, Iterator
|
7 |
+
from typing import Iterator, List, Optional, Tuple
|
8 |
+
import filelock
|
9 |
+
import glob
|
10 |
+
import json
|
11 |
+
import time
|
12 |
+
from gradio.routes import Request
|
13 |
+
from gradio.utils import SyncToAsyncIterator, async_iteration
|
14 |
+
from gradio.helpers import special_args
|
15 |
+
import anyio
|
16 |
+
from typing import AsyncGenerator, Callable, Literal, Union, cast, Generator
|
17 |
+
|
18 |
+
from gradio_client.documentation import document, set_documentation_group
|
19 |
+
from gradio.components import Button, Component
|
20 |
+
from gradio.events import Dependency, EventListenerMethod
|
21 |
+
from typing import List, Optional, Union, Dict, Tuple
|
22 |
+
from tqdm.auto import tqdm
|
23 |
+
from huggingface_hub import snapshot_download
|
24 |
+
from gradio.themes import ThemeClass as Theme
|
25 |
+
|
26 |
+
from .base_demo import register_demo, get_demo_class, BaseDemo
|
27 |
+
|
28 |
+
import inspect
|
29 |
+
from typing import AsyncGenerator, Callable, Literal, Union, cast
|
30 |
+
|
31 |
+
import anyio
|
32 |
+
from gradio_client import utils as client_utils
|
33 |
+
from gradio_client.documentation import document
|
34 |
+
|
35 |
+
from gradio.blocks import Blocks
|
36 |
+
from gradio.components import (
|
37 |
+
Button,
|
38 |
+
Chatbot,
|
39 |
+
Component,
|
40 |
+
Markdown,
|
41 |
+
State,
|
42 |
+
Textbox,
|
43 |
+
get_component_instance,
|
44 |
+
)
|
45 |
+
from gradio.events import Dependency, on
|
46 |
+
from gradio.helpers import create_examples as Examples # noqa: N812
|
47 |
+
from gradio.helpers import special_args
|
48 |
+
from gradio.layouts import Accordion, Group, Row
|
49 |
+
from gradio.routes import Request
|
50 |
+
from gradio.themes import ThemeClass as Theme
|
51 |
+
from gradio.utils import SyncToAsyncIterator, async_iteration
|
52 |
+
|
53 |
+
|
54 |
+
from ..globals import MODEL_ENGINE, RAG_CURRENT_FILE, RAG_EMBED, load_embeddings, get_rag_embeddings
|
55 |
+
|
56 |
+
from .chat_interface import (
|
57 |
+
SYSTEM_PROMPT,
|
58 |
+
MODEL_NAME,
|
59 |
+
MAX_TOKENS,
|
60 |
+
TEMPERATURE,
|
61 |
+
CHAT_EXAMPLES,
|
62 |
+
gradio_history_to_openai_conversations,
|
63 |
+
gradio_history_to_conversation_prompt,
|
64 |
+
DATETIME_FORMAT,
|
65 |
+
get_datetime_string,
|
66 |
+
format_conversation,
|
67 |
+
chat_response_stream_multiturn_engine,
|
68 |
+
ChatInterfaceDemo,
|
69 |
+
CustomizedChatInterface,
|
70 |
+
)
|
71 |
+
|
72 |
+
from ..configs import (
|
73 |
+
CHUNK_SIZE,
|
74 |
+
CHUNK_OVERLAP,
|
75 |
+
RAG_EMBED_MODEL_NAME,
|
76 |
+
)
|
77 |
+
|
78 |
+
RAG_CURRENT_VECTORSTORE = None
|
79 |
+
|
80 |
+
|
81 |
+
def load_document_split_vectorstore(file_path):
|
82 |
+
global RAG_CURRENT_FILE, RAG_EMBED, RAG_CURRENT_VECTORSTORE
|
83 |
+
from langchain.text_splitter import RecursiveCharacterTextSplitter
|
84 |
+
from langchain_community.embeddings import HuggingFaceEmbeddings, HuggingFaceBgeEmbeddings
|
85 |
+
from langchain_community.vectorstores import Chroma, FAISS
|
86 |
+
from langchain_community.document_loaders import PyPDFLoader, Docx2txtLoader, TextLoader
|
87 |
+
splitter = RecursiveCharacterTextSplitter(chunk_size=CHUNK_SIZE, chunk_overlap=CHUNK_OVERLAP)
|
88 |
+
if file_path.endswith('.pdf'):
|
89 |
+
loader = PyPDFLoader(file_path)
|
90 |
+
elif file_path.endswith('.docx'):
|
91 |
+
loader = Docx2txtLoader(file_path)
|
92 |
+
elif file_path.endswith('.txt'):
|
93 |
+
loader = TextLoader(file_path)
|
94 |
+
splits = loader.load_and_split(splitter)
|
95 |
+
RAG_CURRENT_VECTORSTORE = FAISS.from_texts(texts=[s.page_content for s in splits], embedding=get_rag_embeddings())
|
96 |
+
return RAG_CURRENT_VECTORSTORE
|
97 |
+
|
98 |
+
def docs_to_context_content(docs: List[Any]):
|
99 |
+
content = "\n".join([d.page_content for d in docs])
|
100 |
+
return content
|
101 |
+
|
102 |
+
|
103 |
+
DOC_TEMPLATE = """###
|
104 |
+
{content}
|
105 |
+
###
|
106 |
+
|
107 |
+
"""
|
108 |
+
|
109 |
+
DOC_INSTRUCTION = """Answer the following query exclusively based on the information provided in the document above. \
|
110 |
+
If the information is not found, please say so instead of making up facts! Remember to answer the question in the same language as the user query!
|
111 |
+
"""
|
112 |
+
|
113 |
+
|
114 |
+
def docs_to_rag_context(docs: List[Any], doc_instruction=None):
|
115 |
+
doc_instruction = doc_instruction or DOC_INSTRUCTION
|
116 |
+
content = docs_to_context_content(docs)
|
117 |
+
context = doc_instruction.strip() + "\n" + DOC_TEMPLATE.format(content=content)
|
118 |
+
return context
|
119 |
+
|
120 |
+
|
121 |
+
def maybe_get_doc_context(message, file_input, rag_num_docs: Optional[int] = 3):
|
122 |
+
doc_context = None
|
123 |
+
if file_input is not None:
|
124 |
+
if file_input == RAG_CURRENT_FILE:
|
125 |
+
# reuse
|
126 |
+
vectorstore = RAG_CURRENT_VECTORSTORE
|
127 |
+
print(f'Reuse vectorstore: {file_input}')
|
128 |
+
else:
|
129 |
+
vectorstore = load_document_split_vectorstore(file_input)
|
130 |
+
print(f'New vectorstore: {RAG_CURRENT_FILE} {file_input}')
|
131 |
+
RAG_CURRENT_FILE = file_input
|
132 |
+
docs = vectorstore.similarity_search(message, k=rag_num_docs)
|
133 |
+
doc_context = docs_to_rag_context(docs)
|
134 |
+
return doc_context
|
135 |
+
|
136 |
+
|
137 |
+
def chat_response_stream_multiturn_doc_engine(
|
138 |
+
message: str,
|
139 |
+
history: List[Tuple[str, str]],
|
140 |
+
file_input: Optional[str] = None,
|
141 |
+
temperature: float = 0.7,
|
142 |
+
max_tokens: int = 1024,
|
143 |
+
system_prompt: Optional[str] = SYSTEM_PROMPT,
|
144 |
+
rag_num_docs: Optional[int] = 3,
|
145 |
+
doc_instruction: Optional[str] = DOC_INSTRUCTION,
|
146 |
+
# profile: Optional[gr.OAuthProfile] = None,
|
147 |
+
):
|
148 |
+
global MODEL_ENGINE, RAG_CURRENT_FILE, RAG_EMBED, RAG_CURRENT_VECTORSTORE
|
149 |
+
if len(message) == 0:
|
150 |
+
raise gr.Error("The message cannot be empty!")
|
151 |
+
|
152 |
+
rag_num_docs = int(rag_num_docs)
|
153 |
+
doc_instruction = doc_instruction or DOC_INSTRUCTION
|
154 |
+
doc_context = None
|
155 |
+
if file_input is not None:
|
156 |
+
if file_input == RAG_CURRENT_FILE:
|
157 |
+
# reuse
|
158 |
+
vectorstore = RAG_CURRENT_VECTORSTORE
|
159 |
+
print(f'Reuse vectorstore: {file_input}')
|
160 |
+
else:
|
161 |
+
vectorstore = load_document_split_vectorstore(file_input)
|
162 |
+
print(f'New vectorstore: {RAG_CURRENT_FILE} {file_input}')
|
163 |
+
RAG_CURRENT_FILE = file_input
|
164 |
+
docs = vectorstore.similarity_search(message, k=rag_num_docs)
|
165 |
+
# doc_context = docs_to_rag_context(docs)
|
166 |
+
rag_content = docs_to_context_content(docs)
|
167 |
+
doc_context = doc_instruction.strip() + "\n" + DOC_TEMPLATE.format(content=rag_content)
|
168 |
+
|
169 |
+
if doc_context is not None:
|
170 |
+
message = f"{doc_context}\n\n{message}"
|
171 |
+
|
172 |
+
for response, num_tokens in chat_response_stream_multiturn_engine(
|
173 |
+
message, history, temperature, max_tokens, system_prompt
|
174 |
+
):
|
175 |
+
# ! yield another content which is doc_context
|
176 |
+
yield response, num_tokens, doc_context
|
177 |
+
|
178 |
+
|
179 |
+
|
180 |
+
class RagChatInterface(CustomizedChatInterface):
|
181 |
+
def __init__(
|
182 |
+
self,
|
183 |
+
fn: Callable[..., Any],
|
184 |
+
*,
|
185 |
+
chatbot: gr.Chatbot | None = None,
|
186 |
+
textbox: gr.Textbox | None = None,
|
187 |
+
additional_inputs: str | Component | list[str | Component] | None = None,
|
188 |
+
additional_inputs_accordion_name: str | None = None,
|
189 |
+
additional_inputs_accordion: str | gr.Accordion | None = None,
|
190 |
+
render_additional_inputs_fn: Callable | None = None,
|
191 |
+
examples: list[str] | None = None,
|
192 |
+
cache_examples: bool | None = None,
|
193 |
+
title: str | None = None,
|
194 |
+
description: str | None = None,
|
195 |
+
theme: Theme | str | None = None,
|
196 |
+
css: str | None = None,
|
197 |
+
js: str | None = None,
|
198 |
+
head: str | None = None,
|
199 |
+
analytics_enabled: bool | None = None,
|
200 |
+
submit_btn: str | Button | None = "Submit",
|
201 |
+
stop_btn: str | Button | None = "Stop",
|
202 |
+
retry_btn: str | Button | None = "🔄 Retry",
|
203 |
+
undo_btn: str | Button | None = "↩️ Undo",
|
204 |
+
clear_btn: str | Button | None = "🗑️ Clear",
|
205 |
+
autofocus: bool = True,
|
206 |
+
concurrency_limit: int | Literal['default'] | None = "default",
|
207 |
+
fill_height: bool = True
|
208 |
+
):
|
209 |
+
try:
|
210 |
+
super(gr.ChatInterface, self).__init__(
|
211 |
+
analytics_enabled=analytics_enabled,
|
212 |
+
mode="chat_interface",
|
213 |
+
css=css,
|
214 |
+
title=title or "Gradio",
|
215 |
+
theme=theme,
|
216 |
+
js=js,
|
217 |
+
head=head,
|
218 |
+
fill_height=fill_height,
|
219 |
+
)
|
220 |
+
except Exception as e:
|
221 |
+
# Handling some old gradio version with out fill_height
|
222 |
+
super(gr.ChatInterface, self).__init__(
|
223 |
+
analytics_enabled=analytics_enabled,
|
224 |
+
mode="chat_interface",
|
225 |
+
css=css,
|
226 |
+
title=title or "Gradio",
|
227 |
+
theme=theme,
|
228 |
+
js=js,
|
229 |
+
head=head,
|
230 |
+
# fill_height=fill_height,
|
231 |
+
)
|
232 |
+
self.concurrency_limit = concurrency_limit
|
233 |
+
self.fn = fn
|
234 |
+
self.render_additional_inputs_fn = render_additional_inputs_fn
|
235 |
+
self.is_async = inspect.iscoroutinefunction(
|
236 |
+
self.fn
|
237 |
+
) or inspect.isasyncgenfunction(self.fn)
|
238 |
+
self.is_generator = inspect.isgeneratorfunction(
|
239 |
+
self.fn
|
240 |
+
) or inspect.isasyncgenfunction(self.fn)
|
241 |
+
self.examples = examples
|
242 |
+
if self.space_id and cache_examples is None:
|
243 |
+
self.cache_examples = True
|
244 |
+
else:
|
245 |
+
self.cache_examples = cache_examples or False
|
246 |
+
self.buttons: list[Button | None] = []
|
247 |
+
|
248 |
+
if additional_inputs:
|
249 |
+
if not isinstance(additional_inputs, list):
|
250 |
+
additional_inputs = [additional_inputs]
|
251 |
+
self.additional_inputs = [
|
252 |
+
get_component_instance(i)
|
253 |
+
for i in additional_inputs # type: ignore
|
254 |
+
]
|
255 |
+
else:
|
256 |
+
self.additional_inputs = []
|
257 |
+
if additional_inputs_accordion_name is not None:
|
258 |
+
print(
|
259 |
+
"The `additional_inputs_accordion_name` parameter is deprecated and will be removed in a future version of Gradio. Use the `additional_inputs_accordion` parameter instead."
|
260 |
+
)
|
261 |
+
self.additional_inputs_accordion_params = {
|
262 |
+
"label": additional_inputs_accordion_name
|
263 |
+
}
|
264 |
+
if additional_inputs_accordion is None:
|
265 |
+
self.additional_inputs_accordion_params = {
|
266 |
+
"label": "Additional Inputs",
|
267 |
+
"open": False,
|
268 |
+
}
|
269 |
+
elif isinstance(additional_inputs_accordion, str):
|
270 |
+
self.additional_inputs_accordion_params = {
|
271 |
+
"label": additional_inputs_accordion
|
272 |
+
}
|
273 |
+
elif isinstance(additional_inputs_accordion, Accordion):
|
274 |
+
self.additional_inputs_accordion_params = (
|
275 |
+
additional_inputs_accordion.recover_kwargs(
|
276 |
+
additional_inputs_accordion.get_config()
|
277 |
+
)
|
278 |
+
)
|
279 |
+
else:
|
280 |
+
raise ValueError(
|
281 |
+
f"The `additional_inputs_accordion` parameter must be a string or gr.Accordion, not {type(additional_inputs_accordion)}"
|
282 |
+
)
|
283 |
+
|
284 |
+
with self:
|
285 |
+
if title:
|
286 |
+
Markdown(
|
287 |
+
f"<h1 style='text-align: center; margin-bottom: 1rem'>{self.title}</h1>"
|
288 |
+
)
|
289 |
+
if description:
|
290 |
+
Markdown(description)
|
291 |
+
|
292 |
+
if chatbot:
|
293 |
+
self.chatbot = chatbot.render()
|
294 |
+
else:
|
295 |
+
self.chatbot = Chatbot(
|
296 |
+
label="Chatbot", scale=1, height=200 if fill_height else None
|
297 |
+
)
|
298 |
+
|
299 |
+
with Row():
|
300 |
+
for btn in [retry_btn, undo_btn, clear_btn]:
|
301 |
+
if btn is not None:
|
302 |
+
if isinstance(btn, Button):
|
303 |
+
btn.render()
|
304 |
+
elif isinstance(btn, str):
|
305 |
+
btn = Button(btn, variant="secondary", size="sm")
|
306 |
+
else:
|
307 |
+
raise ValueError(
|
308 |
+
f"All the _btn parameters must be a gr.Button, string, or None, not {type(btn)}"
|
309 |
+
)
|
310 |
+
self.buttons.append(btn) # type: ignore
|
311 |
+
|
312 |
+
with Group():
|
313 |
+
with Row():
|
314 |
+
if textbox:
|
315 |
+
textbox.container = False
|
316 |
+
textbox.show_label = False
|
317 |
+
textbox_ = textbox.render()
|
318 |
+
assert isinstance(textbox_, Textbox)
|
319 |
+
self.textbox = textbox_
|
320 |
+
else:
|
321 |
+
self.textbox = Textbox(
|
322 |
+
container=False,
|
323 |
+
show_label=False,
|
324 |
+
label="Message",
|
325 |
+
placeholder="Type a message...",
|
326 |
+
scale=7,
|
327 |
+
autofocus=autofocus,
|
328 |
+
)
|
329 |
+
if submit_btn is not None:
|
330 |
+
if isinstance(submit_btn, Button):
|
331 |
+
submit_btn.render()
|
332 |
+
elif isinstance(submit_btn, str):
|
333 |
+
submit_btn = Button(
|
334 |
+
submit_btn,
|
335 |
+
variant="primary",
|
336 |
+
scale=2,
|
337 |
+
min_width=150,
|
338 |
+
)
|
339 |
+
else:
|
340 |
+
raise ValueError(
|
341 |
+
f"The submit_btn parameter must be a gr.Button, string, or None, not {type(submit_btn)}"
|
342 |
+
)
|
343 |
+
if stop_btn is not None:
|
344 |
+
if isinstance(stop_btn, Button):
|
345 |
+
stop_btn.visible = False
|
346 |
+
stop_btn.render()
|
347 |
+
elif isinstance(stop_btn, str):
|
348 |
+
stop_btn = Button(
|
349 |
+
stop_btn,
|
350 |
+
variant="stop",
|
351 |
+
visible=False,
|
352 |
+
scale=2,
|
353 |
+
min_width=150,
|
354 |
+
)
|
355 |
+
else:
|
356 |
+
raise ValueError(
|
357 |
+
f"The stop_btn parameter must be a gr.Button, string, or None, not {type(stop_btn)}"
|
358 |
+
)
|
359 |
+
self.num_tokens = Textbox(
|
360 |
+
container=False,
|
361 |
+
label="num_tokens",
|
362 |
+
placeholder="0 tokens",
|
363 |
+
scale=1,
|
364 |
+
interactive=False,
|
365 |
+
# autofocus=autofocus,
|
366 |
+
min_width=10
|
367 |
+
)
|
368 |
+
self.buttons.extend([submit_btn, stop_btn]) # type: ignore
|
369 |
+
|
370 |
+
self.fake_api_btn = Button("Fake API", visible=False)
|
371 |
+
self.fake_response_textbox = Textbox(label="Response", visible=False)
|
372 |
+
(
|
373 |
+
self.retry_btn,
|
374 |
+
self.undo_btn,
|
375 |
+
self.clear_btn,
|
376 |
+
self.submit_btn,
|
377 |
+
self.stop_btn,
|
378 |
+
) = self.buttons
|
379 |
+
|
380 |
+
if examples:
|
381 |
+
if self.is_generator:
|
382 |
+
examples_fn = self._examples_stream_fn
|
383 |
+
else:
|
384 |
+
examples_fn = self._examples_fn
|
385 |
+
|
386 |
+
self.examples_handler = Examples(
|
387 |
+
examples=examples,
|
388 |
+
inputs=[self.textbox] + self.additional_inputs,
|
389 |
+
outputs=self.chatbot,
|
390 |
+
fn=examples_fn,
|
391 |
+
)
|
392 |
+
|
393 |
+
any_unrendered_inputs = any(
|
394 |
+
not inp.is_rendered for inp in self.additional_inputs
|
395 |
+
)
|
396 |
+
if self.additional_inputs and any_unrendered_inputs:
|
397 |
+
with Accordion(**self.additional_inputs_accordion_params): # type: ignore
|
398 |
+
if self.render_additional_inputs_fn is not None:
|
399 |
+
self.render_additional_inputs_fn()
|
400 |
+
else:
|
401 |
+
for input_component in self.additional_inputs:
|
402 |
+
if not input_component.is_rendered:
|
403 |
+
input_component.render()
|
404 |
+
|
405 |
+
self.rag_content = gr.Textbox(
|
406 |
+
scale=4,
|
407 |
+
lines=16,
|
408 |
+
label='Retrieved RAG context',
|
409 |
+
placeholder="Rag context and instrution will show up here",
|
410 |
+
interactive=False
|
411 |
+
)
|
412 |
+
|
413 |
+
# The example caching must happen after the input components have rendered
|
414 |
+
if cache_examples:
|
415 |
+
client_utils.synchronize_async(self.examples_handler.cache)
|
416 |
+
|
417 |
+
self.saved_input = State()
|
418 |
+
self.chatbot_state = (
|
419 |
+
State(self.chatbot.value) if self.chatbot.value else State([])
|
420 |
+
)
|
421 |
+
|
422 |
+
self._setup_events()
|
423 |
+
self._setup_api()
|
424 |
+
|
425 |
+
def _setup_events(self) -> None:
|
426 |
+
from gradio.components import State
|
427 |
+
has_on = False
|
428 |
+
try:
|
429 |
+
from gradio.events import Dependency, EventListenerMethod, on
|
430 |
+
has_on = True
|
431 |
+
except ImportError as ie:
|
432 |
+
has_on = False
|
433 |
+
submit_fn = self._stream_fn if self.is_generator else self._submit_fn
|
434 |
+
if not self.is_generator:
|
435 |
+
raise NotImplementedError(f'should use generator')
|
436 |
+
|
437 |
+
if has_on:
|
438 |
+
# new version
|
439 |
+
submit_triggers = (
|
440 |
+
[self.textbox.submit, self.submit_btn.click]
|
441 |
+
if self.submit_btn
|
442 |
+
else [self.textbox.submit]
|
443 |
+
)
|
444 |
+
submit_event = (
|
445 |
+
on(
|
446 |
+
submit_triggers,
|
447 |
+
self._clear_and_save_textbox,
|
448 |
+
[self.textbox],
|
449 |
+
[self.textbox, self.saved_input],
|
450 |
+
api_name=False,
|
451 |
+
queue=False,
|
452 |
+
)
|
453 |
+
.then(
|
454 |
+
self._display_input,
|
455 |
+
[self.saved_input, self.chatbot_state],
|
456 |
+
[self.chatbot, self.chatbot_state],
|
457 |
+
api_name=False,
|
458 |
+
queue=False,
|
459 |
+
)
|
460 |
+
.then(
|
461 |
+
submit_fn,
|
462 |
+
[self.saved_input, self.chatbot_state] + self.additional_inputs,
|
463 |
+
[self.chatbot, self.chatbot_state, self.num_tokens, self.rag_content],
|
464 |
+
api_name=False,
|
465 |
+
)
|
466 |
+
)
|
467 |
+
self._setup_stop_events(submit_triggers, submit_event)
|
468 |
+
else:
|
469 |
+
raise ValueError(f'Better install new gradio version than 3.44.0')
|
470 |
+
|
471 |
+
if self.retry_btn:
|
472 |
+
retry_event = (
|
473 |
+
self.retry_btn.click(
|
474 |
+
self._delete_prev_fn,
|
475 |
+
[self.chatbot_state],
|
476 |
+
[self.chatbot, self.saved_input, self.chatbot_state],
|
477 |
+
api_name=False,
|
478 |
+
queue=False,
|
479 |
+
)
|
480 |
+
.then(
|
481 |
+
self._display_input,
|
482 |
+
[self.saved_input, self.chatbot_state],
|
483 |
+
[self.chatbot, self.chatbot_state],
|
484 |
+
api_name=False,
|
485 |
+
queue=False,
|
486 |
+
)
|
487 |
+
.then(
|
488 |
+
submit_fn,
|
489 |
+
[self.saved_input, self.chatbot_state] + self.additional_inputs,
|
490 |
+
[self.chatbot, self.chatbot_state, self.num_tokens, self.rag_content],
|
491 |
+
api_name=False,
|
492 |
+
)
|
493 |
+
)
|
494 |
+
self._setup_stop_events([self.retry_btn.click], retry_event)
|
495 |
+
|
496 |
+
if self.undo_btn:
|
497 |
+
self.undo_btn.click(
|
498 |
+
self._delete_prev_fn,
|
499 |
+
[self.chatbot_state],
|
500 |
+
[self.chatbot, self.saved_input, self.chatbot_state],
|
501 |
+
api_name=False,
|
502 |
+
queue=False,
|
503 |
+
).then(
|
504 |
+
lambda x: x,
|
505 |
+
[self.saved_input],
|
506 |
+
[self.textbox],
|
507 |
+
api_name=False,
|
508 |
+
queue=False,
|
509 |
+
)
|
510 |
+
# Reconfigure clear_btn to stop and clear text box
|
511 |
+
|
512 |
+
async def _stream_fn(
|
513 |
+
self,
|
514 |
+
message: str,
|
515 |
+
history_with_input,
|
516 |
+
request: Request,
|
517 |
+
*args,
|
518 |
+
) -> AsyncGenerator:
|
519 |
+
history = history_with_input[:-1]
|
520 |
+
inputs, _, _ = special_args(
|
521 |
+
self.fn, inputs=[message, history, *args], request=request
|
522 |
+
)
|
523 |
+
|
524 |
+
if self.is_async:
|
525 |
+
generator = self.fn(*inputs)
|
526 |
+
else:
|
527 |
+
generator = await anyio.to_thread.run_sync(
|
528 |
+
self.fn, *inputs, limiter=self.limiter
|
529 |
+
)
|
530 |
+
generator = SyncToAsyncIterator(generator, self.limiter)
|
531 |
+
|
532 |
+
# ! In case of error, yield the previous history & undo any generation before raising error
|
533 |
+
try:
|
534 |
+
first_response_pack = await async_iteration(generator)
|
535 |
+
if isinstance(first_response_pack, (tuple, list)):
|
536 |
+
first_response, num_tokens, rag_content = first_response_pack
|
537 |
+
else:
|
538 |
+
first_response, num_tokens, rag_content = first_response_pack, -1, ""
|
539 |
+
update = history + [[message, first_response]]
|
540 |
+
yield update, update, f"{num_tokens} toks", rag_content
|
541 |
+
except StopIteration:
|
542 |
+
update = history + [[message, None]]
|
543 |
+
yield update, update, "NaN toks", ""
|
544 |
+
except Exception as e:
|
545 |
+
yield history, history, "NaN toks", ""
|
546 |
+
raise e
|
547 |
+
|
548 |
+
try:
|
549 |
+
async for response_pack in generator:
|
550 |
+
if isinstance(response_pack, (tuple, list)):
|
551 |
+
response, num_tokens, rag_content = response_pack
|
552 |
+
else:
|
553 |
+
response, num_tokens, rag_content = response_pack, "NaN toks", ""
|
554 |
+
update = history + [[message, response]]
|
555 |
+
yield update, update, f"{num_tokens} toks", rag_content
|
556 |
+
except Exception as e:
|
557 |
+
yield history, history, "NaN toks", ""
|
558 |
+
raise e
|
559 |
+
|
560 |
+
|
561 |
+
|
562 |
+
@register_demo
|
563 |
+
class RagChatInterfaceDemo(ChatInterfaceDemo):
|
564 |
+
|
565 |
+
@property
|
566 |
+
def examples(self):
|
567 |
+
return [
|
568 |
+
["Explain how attention works.", "assets/attention_all_you_need.pdf"],
|
569 |
+
["Explain why the sky is blue.", None],
|
570 |
+
]
|
571 |
+
|
572 |
+
@property
|
573 |
+
def tab_name(self):
|
574 |
+
return "RAG Chat"
|
575 |
+
|
576 |
+
def create_demo(
|
577 |
+
self,
|
578 |
+
title: str | None = None,
|
579 |
+
description: str | None = None,
|
580 |
+
**kwargs
|
581 |
+
) -> gr.Blocks:
|
582 |
+
load_embeddings()
|
583 |
+
global RAG_EMBED
|
584 |
+
# assert RAG_EMBED is not None
|
585 |
+
print(F'{RAG_EMBED=}')
|
586 |
+
system_prompt = kwargs.get("system_prompt", SYSTEM_PROMPT)
|
587 |
+
max_tokens = kwargs.get("max_tokens", MAX_TOKENS)
|
588 |
+
temperature = kwargs.get("temperature", TEMPERATURE)
|
589 |
+
model_name = kwargs.get("model_name", MODEL_NAME)
|
590 |
+
rag_num_docs = kwargs.get("rag_num_docs", 3)
|
591 |
+
|
592 |
+
from ..configs import RAG_EMBED_MODEL_NAME
|
593 |
+
|
594 |
+
description = (
|
595 |
+
description or
|
596 |
+
f"""Upload a long document to ask question with RAG. Check at the bottom the retrieved RAG text segment.
|
597 |
+
Control `RAG instruction to fit your language`. Embedding model {RAG_EMBED_MODEL_NAME}."""
|
598 |
+
)
|
599 |
+
|
600 |
+
additional_inputs = [
|
601 |
+
gr.File(label='Upload Document', file_count='single', file_types=['pdf', 'docx', 'txt']),
|
602 |
+
gr.Number(value=temperature, label='Temperature', min_width=20),
|
603 |
+
gr.Number(value=max_tokens, label='Max tokens', min_width=20),
|
604 |
+
gr.Textbox(value=system_prompt, label='System prompt', lines=2),
|
605 |
+
gr.Number(value=rag_num_docs, label='RAG Top-K', min_width=20),
|
606 |
+
gr.Textbox(value=DOC_INSTRUCTION, label='RAG instruction'),
|
607 |
+
]
|
608 |
+
def render_additional_inputs_fn():
|
609 |
+
additional_inputs[0].render()
|
610 |
+
with Row():
|
611 |
+
additional_inputs[1].render()
|
612 |
+
additional_inputs[2].render()
|
613 |
+
additional_inputs[4].render()
|
614 |
+
additional_inputs[3].render()
|
615 |
+
additional_inputs[5].render()
|
616 |
+
|
617 |
+
demo_chat = RagChatInterface(
|
618 |
+
chat_response_stream_multiturn_doc_engine,
|
619 |
+
chatbot=gr.Chatbot(
|
620 |
+
label=model_name,
|
621 |
+
bubble_full_width=False,
|
622 |
+
latex_delimiters=[
|
623 |
+
{ "left": "$", "right": "$", "display": False},
|
624 |
+
{ "left": "$$", "right": "$$", "display": True},
|
625 |
+
],
|
626 |
+
show_copy_button=True,
|
627 |
+
),
|
628 |
+
textbox=gr.Textbox(placeholder='Type message', lines=1, max_lines=128, min_width=200, scale=8),
|
629 |
+
submit_btn=gr.Button(value='Submit', variant="primary", scale=0),
|
630 |
+
# ! consider preventing the stop button
|
631 |
+
# stop_btn=None,
|
632 |
+
title=title,
|
633 |
+
description=description,
|
634 |
+
additional_inputs=additional_inputs,
|
635 |
+
render_additional_inputs_fn=render_additional_inputs_fn,
|
636 |
+
additional_inputs_accordion=gr.Accordion("Additional Inputs", open=True),
|
637 |
+
examples=self.examples,
|
638 |
+
cache_examples=False,
|
639 |
+
)
|
640 |
+
return demo_chat
|
641 |
+
|
642 |
+
|
multipurpose_chatbot/demos/text_completion.py
ADDED
@@ -0,0 +1,199 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
from gradio.themes import ThemeClass as Theme
|
3 |
+
import numpy as np
|
4 |
+
import argparse
|
5 |
+
import gradio as gr
|
6 |
+
from typing import Any, Iterator
|
7 |
+
from typing import Iterator, List, Optional, Tuple
|
8 |
+
import filelock
|
9 |
+
import glob
|
10 |
+
import json
|
11 |
+
import time
|
12 |
+
from gradio.routes import Request
|
13 |
+
from gradio.utils import SyncToAsyncIterator, async_iteration
|
14 |
+
from gradio.helpers import special_args
|
15 |
+
import anyio
|
16 |
+
from typing import AsyncGenerator, Callable, Literal, Union, cast, Generator
|
17 |
+
|
18 |
+
from gradio_client.documentation import document, set_documentation_group
|
19 |
+
from gradio.components import Button, Component
|
20 |
+
from gradio.events import Dependency, EventListenerMethod
|
21 |
+
from typing import List, Optional, Union, Dict, Tuple
|
22 |
+
from tqdm.auto import tqdm
|
23 |
+
from huggingface_hub import snapshot_download
|
24 |
+
|
25 |
+
|
26 |
+
import inspect
|
27 |
+
from typing import AsyncGenerator, Callable, Literal, Union, cast
|
28 |
+
|
29 |
+
import anyio
|
30 |
+
from gradio_client import utils as client_utils
|
31 |
+
from gradio_client.documentation import document
|
32 |
+
|
33 |
+
from gradio.blocks import Blocks
|
34 |
+
from gradio.components import (
|
35 |
+
Button,
|
36 |
+
Chatbot,
|
37 |
+
Component,
|
38 |
+
Markdown,
|
39 |
+
State,
|
40 |
+
Textbox,
|
41 |
+
get_component_instance,
|
42 |
+
)
|
43 |
+
from gradio.events import Dependency, on
|
44 |
+
from gradio.helpers import create_examples as Examples # noqa: N812
|
45 |
+
from gradio.helpers import special_args
|
46 |
+
from gradio.layouts import Accordion, Group, Row
|
47 |
+
from gradio.routes import Request
|
48 |
+
from gradio.themes import ThemeClass as Theme
|
49 |
+
from gradio.utils import SyncToAsyncIterator, async_iteration
|
50 |
+
|
51 |
+
|
52 |
+
from .base_demo import register_demo, get_demo_class, BaseDemo
|
53 |
+
|
54 |
+
|
55 |
+
from ..configs import (
|
56 |
+
SYSTEM_PROMPT,
|
57 |
+
MODEL_NAME,
|
58 |
+
MAX_TOKENS,
|
59 |
+
TEMPERATURE,
|
60 |
+
)
|
61 |
+
|
62 |
+
from ..globals import MODEL_ENGINE
|
63 |
+
|
64 |
+
|
65 |
+
def generate_text_completion_stream_engine(
|
66 |
+
message: str,
|
67 |
+
temperature: float,
|
68 |
+
max_tokens: int,
|
69 |
+
stop_strings: str = '<s>,</s>,<|im_start|>,<|im_end|>',
|
70 |
+
):
|
71 |
+
global MODEL_ENGINE
|
72 |
+
temperature = float(temperature)
|
73 |
+
# ! remove frequency_penalty
|
74 |
+
# frequency_penalty = float(frequency_penalty)
|
75 |
+
max_tokens = int(max_tokens)
|
76 |
+
# message = message.strip()
|
77 |
+
stop_strings = [x.strip() for x in stop_strings.strip().split(",")]
|
78 |
+
stop_strings = list(set(stop_strings + ['</s>', '<|im_start|>', '<|im_end|>']))
|
79 |
+
if message.strip() != message:
|
80 |
+
gr.Warning(f'There are preceding/trailing spaces in the message, may lead to unexpected behavior')
|
81 |
+
if len(message) == 0:
|
82 |
+
raise gr.Error("The message cannot be empty!")
|
83 |
+
num_tokens = len(MODEL_ENGINE.tokenizer.encode(message))
|
84 |
+
if num_tokens >= MODEL_ENGINE.max_position_embeddings - 128:
|
85 |
+
raise gr.Error(f"Conversation or prompt is too long ({num_tokens} toks), please clear the chatbox or try shorter input.")
|
86 |
+
|
87 |
+
outputs = None
|
88 |
+
response = None
|
89 |
+
num_tokens = -1
|
90 |
+
for j, outputs in enumerate(MODEL_ENGINE.generate_yield_string(
|
91 |
+
prompt=message,
|
92 |
+
temperature=temperature,
|
93 |
+
max_tokens=max_tokens,
|
94 |
+
stop_strings=stop_strings,
|
95 |
+
)):
|
96 |
+
if isinstance(outputs, tuple):
|
97 |
+
response, num_tokens = outputs
|
98 |
+
else:
|
99 |
+
response, num_tokens = outputs, -1
|
100 |
+
yield message + response, f"{num_tokens} tokens"
|
101 |
+
|
102 |
+
if response is not None:
|
103 |
+
yield message + response, f"{num_tokens} tokens"
|
104 |
+
|
105 |
+
|
106 |
+
@register_demo
|
107 |
+
class TextCompletionDemo(BaseDemo):
|
108 |
+
@property
|
109 |
+
def tab_name(self):
|
110 |
+
return "Text Completion"
|
111 |
+
|
112 |
+
def create_demo(
|
113 |
+
self,
|
114 |
+
title: str | None = None,
|
115 |
+
description: str | None = None,
|
116 |
+
**kwargs
|
117 |
+
) -> gr.Blocks:
|
118 |
+
system_prompt = kwargs.get("system_prompt", SYSTEM_PROMPT)
|
119 |
+
max_tokens = kwargs.get("max_tokens", MAX_TOKENS)
|
120 |
+
temperature = kwargs.get("temperature", TEMPERATURE)
|
121 |
+
model_name = kwargs.get("model_name", MODEL_NAME)
|
122 |
+
# frequence_penalty = FREQUENCE_PENALTY
|
123 |
+
# presence_penalty = PRESENCE_PENALTY
|
124 |
+
max_tokens = max_tokens // 2
|
125 |
+
|
126 |
+
description = description or f"""Put any context string (like few-shot prompts)"""
|
127 |
+
|
128 |
+
with gr.Blocks() as demo_text_completion:
|
129 |
+
if title:
|
130 |
+
gr.Markdown(title)
|
131 |
+
if description:
|
132 |
+
gr.Markdown(description)
|
133 |
+
with gr.Row():
|
134 |
+
txt = gr.Textbox(
|
135 |
+
scale=4,
|
136 |
+
lines=16,
|
137 |
+
show_label=False,
|
138 |
+
placeholder="Enter any free form text and submit",
|
139 |
+
container=False,
|
140 |
+
)
|
141 |
+
with gr.Row():
|
142 |
+
submit_button = gr.Button('Submit', variant='primary', scale=9)
|
143 |
+
stop_button = gr.Button('Stop', variant='stop', scale=9, visible=False)
|
144 |
+
num_tokens = Textbox(
|
145 |
+
container=False,
|
146 |
+
show_label=False,
|
147 |
+
label="num_tokens",
|
148 |
+
placeholder="0 tokens",
|
149 |
+
scale=1,
|
150 |
+
interactive=False,
|
151 |
+
min_width=10
|
152 |
+
)
|
153 |
+
with gr.Row():
|
154 |
+
temp_input = gr.Number(value=temperature, label='Temperature', info="Higher -> more random")
|
155 |
+
length_input = gr.Number(value=max_tokens, label='Max tokens', info='Increase if want more generation')
|
156 |
+
stop_strings = gr.Textbox(value="<s>,</s>,<|im_start|>,<|im_end|>", label='Stop strings', info='Comma-separated string to stop generation only in FEW-SHOT mode', lines=1)
|
157 |
+
examples = gr.Examples(
|
158 |
+
examples=[
|
159 |
+
["The following is the recite the declaration of independence:",]
|
160 |
+
],
|
161 |
+
inputs=[txt, temp_input, length_input, stop_strings],
|
162 |
+
# outputs=[txt]
|
163 |
+
)
|
164 |
+
# ! Handle stop button
|
165 |
+
submit_trigger = submit_button.click
|
166 |
+
submit_event = submit_button.click(
|
167 |
+
# submit_trigger,
|
168 |
+
generate_text_completion_stream_engine,
|
169 |
+
[txt, temp_input, length_input, stop_strings],
|
170 |
+
[txt, num_tokens],
|
171 |
+
# api_name=False,
|
172 |
+
# queue=False,
|
173 |
+
)
|
174 |
+
|
175 |
+
submit_trigger(
|
176 |
+
lambda: (
|
177 |
+
Button(visible=False), Button(visible=True),
|
178 |
+
),
|
179 |
+
None,
|
180 |
+
[submit_button, stop_button],
|
181 |
+
api_name=False,
|
182 |
+
queue=False,
|
183 |
+
)
|
184 |
+
submit_event.then(
|
185 |
+
lambda: (Button(visible=True), Button(visible=False)),
|
186 |
+
None,
|
187 |
+
[submit_button, stop_button],
|
188 |
+
api_name=False,
|
189 |
+
queue=False,
|
190 |
+
)
|
191 |
+
stop_button.click(
|
192 |
+
None,
|
193 |
+
None,
|
194 |
+
None,
|
195 |
+
cancels=submit_event,
|
196 |
+
api_name=False,
|
197 |
+
)
|
198 |
+
|
199 |
+
return demo_text_completion
|
multipurpose_chatbot/engines/.DS_Store
ADDED
Binary file (6.15 kB). View file
|
|
multipurpose_chatbot/engines/__init__.py
ADDED
@@ -0,0 +1,54 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
|
2 |
+
from .base_engine import BaseEngine
|
3 |
+
|
4 |
+
BACKENDS = [
|
5 |
+
"mlx",
|
6 |
+
"vllm",
|
7 |
+
"transformers",
|
8 |
+
"llava15_transformers",
|
9 |
+
"llama_cpp",
|
10 |
+
# "llava_llama_cpp",
|
11 |
+
"debug",
|
12 |
+
]
|
13 |
+
|
14 |
+
ENGINE_LOADED = False
|
15 |
+
|
16 |
+
|
17 |
+
def load_multipurpose_chatbot_engine(backend: str):
|
18 |
+
# ! lazy import other engines
|
19 |
+
global ENGINE_LOADED
|
20 |
+
assert backend in BACKENDS, f'{backend} not in {BACKENDS}'
|
21 |
+
if ENGINE_LOADED:
|
22 |
+
raise RuntimeError(f'{ENGINE_LOADED=} this means load_multipurpose_chatbot_engine has already been called! Check your codes.')
|
23 |
+
print(f'Load model from {backend}')
|
24 |
+
if backend == "mlx":
|
25 |
+
from .mlx_engine import MlxEngine
|
26 |
+
model_engine = MlxEngine()
|
27 |
+
elif backend == 'vllm':
|
28 |
+
from .vllm_engine import VllmEngine
|
29 |
+
model_engine = VllmEngine()
|
30 |
+
elif backend == 'transformers':
|
31 |
+
from .transformers_engine import TransformersEngine
|
32 |
+
model_engine = TransformersEngine()
|
33 |
+
elif backend == 'llava15_transformers':
|
34 |
+
from .llava15_transformers_engine import Llava15TransformersEngine
|
35 |
+
model_engine = Llava15TransformersEngine()
|
36 |
+
elif backend == 'llama_cpp':
|
37 |
+
from .llama_cpp_engine import LlamaCppEngine
|
38 |
+
model_engine = LlamaCppEngine()
|
39 |
+
# ! llava_llama_cpp currently not done due to bugs
|
40 |
+
# elif backend == 'llava_llama_cpp':
|
41 |
+
# from .llava_llama_cpp_engine import LlavaLlamaCppEngine
|
42 |
+
# model_engine = LlavaLlamaCppEngine()
|
43 |
+
elif backend == 'debug':
|
44 |
+
from .debug_engine import DebugEngine
|
45 |
+
model_engine = DebugEngine()
|
46 |
+
else:
|
47 |
+
raise ValueError(f'backend invalid: {BACKENDS} vs {backend}')
|
48 |
+
|
49 |
+
model_engine.load_model()
|
50 |
+
ENGINE_LOADED = True
|
51 |
+
return model_engine
|
52 |
+
# ! add more llama.cpp engine here.
|
53 |
+
|
54 |
+
|
multipurpose_chatbot/engines/base_engine.py
ADDED
@@ -0,0 +1,46 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import numpy as np
|
3 |
+
from huggingface_hub import snapshot_download
|
4 |
+
# ! Avoid importing transformers
|
5 |
+
# from transformers import AutoConfig, AutoTokenizer, PreTrainedTokenizer
|
6 |
+
from typing import Any, Callable, Dict, Generator, List, Optional, Tuple, Union
|
7 |
+
import time
|
8 |
+
|
9 |
+
|
10 |
+
class BaseEngine(object):
|
11 |
+
def __init__(self, **kwargs) -> None:
|
12 |
+
pass
|
13 |
+
|
14 |
+
@property
|
15 |
+
def max_position_embeddings(self) -> int:
|
16 |
+
return 10000
|
17 |
+
|
18 |
+
@property
|
19 |
+
def tokenizer(self):
|
20 |
+
raise NotImplementedError
|
21 |
+
|
22 |
+
@property
|
23 |
+
def processor(self):
|
24 |
+
raise NotImplementedError
|
25 |
+
|
26 |
+
def load_model(self, ):
|
27 |
+
raise NotImplementedError
|
28 |
+
|
29 |
+
def apply_chat_template(self, conversations, add_generation_prompt: bool, add_special_tokens=False, **kwargs) -> str:
|
30 |
+
"""
|
31 |
+
return string convo, add_special_tokens should be added later
|
32 |
+
"""
|
33 |
+
bos_token = self.tokenizer.bos_token
|
34 |
+
eos_token = self.tokenizer.eos_token
|
35 |
+
if not add_special_tokens:
|
36 |
+
# prevent bos being added to string
|
37 |
+
self.tokenizer.bos_token = ""
|
38 |
+
self.tokenizer.eos_token = ""
|
39 |
+
full_prompt = self.tokenizer.apply_chat_template(
|
40 |
+
conversations, add_generation_prompt=add_generation_prompt,
|
41 |
+
tokenize=False,
|
42 |
+
)
|
43 |
+
self.tokenizer.bos_token = bos_token
|
44 |
+
self.tokenizer.eos_token = eos_token
|
45 |
+
return full_prompt
|
46 |
+
|
multipurpose_chatbot/engines/debug_engine.py
ADDED
@@ -0,0 +1,49 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import numpy as np
|
3 |
+
from huggingface_hub import snapshot_download
|
4 |
+
from transformers import AutoConfig, AutoTokenizer, PreTrainedTokenizer
|
5 |
+
from typing import Any, Callable, Dict, Generator, List, Optional, Tuple, Union
|
6 |
+
import time
|
7 |
+
|
8 |
+
from .base_engine import BaseEngine
|
9 |
+
|
10 |
+
from ..configs import (
|
11 |
+
MODEL_PATH,
|
12 |
+
)
|
13 |
+
|
14 |
+
FAKE_MODEL_PATH = os.environ.get("FAKE_MODEL_PATH", MODEL_PATH)
|
15 |
+
FAKE_RESPONSE = "Wow that's very very cool, please try again."
|
16 |
+
|
17 |
+
|
18 |
+
class DebugEngine(BaseEngine):
|
19 |
+
"""
|
20 |
+
It will always yield FAKE_RESPONSE
|
21 |
+
"""
|
22 |
+
|
23 |
+
def __init__(self, **kwargs) -> None:
|
24 |
+
super().__init__(**kwargs)
|
25 |
+
self._model = None
|
26 |
+
self._tokenizer = None
|
27 |
+
|
28 |
+
@property
|
29 |
+
def tokenizer(self) -> PreTrainedTokenizer:
|
30 |
+
if self._tokenizer is None:
|
31 |
+
self._tokenizer = AutoTokenizer.from_pretrained(FAKE_MODEL_PATH, trust_remote_code=True)
|
32 |
+
return self._tokenizer
|
33 |
+
|
34 |
+
def load_model(self):
|
35 |
+
print(f"Load fake model with tokenizer: {self.tokenizer}")
|
36 |
+
|
37 |
+
def generate_yield_string(self, prompt, temperature, max_tokens, stop_strings: Optional[Tuple[str]] = None, **kwargs):
|
38 |
+
|
39 |
+
num_tokens = len(self.tokenizer.encode(prompt))
|
40 |
+
response = FAKE_RESPONSE
|
41 |
+
for i in range(len(response)):
|
42 |
+
time.sleep(0.01)
|
43 |
+
yield response[:i], num_tokens
|
44 |
+
|
45 |
+
num_tokens = len(self.tokenizer.encode(prompt + response))
|
46 |
+
yield response, num_tokens
|
47 |
+
|
48 |
+
def batch_generate(self, prompts, temperature, max_tokens, stop_strings: Optional[Tuple[str]] = None, **kwargs):
|
49 |
+
return [p + " -- Test" for p in prompts]
|
multipurpose_chatbot/engines/llama_cpp_engine.py
ADDED
@@ -0,0 +1,131 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import numpy as np
|
3 |
+
import argparse
|
4 |
+
import gradio as gr
|
5 |
+
from typing import Any, Iterator
|
6 |
+
from typing import Iterator, List, Optional, Tuple
|
7 |
+
import filelock
|
8 |
+
import glob
|
9 |
+
import json
|
10 |
+
import time
|
11 |
+
from gradio.routes import Request
|
12 |
+
from gradio.utils import SyncToAsyncIterator, async_iteration
|
13 |
+
from gradio.helpers import special_args
|
14 |
+
import anyio
|
15 |
+
from typing import AsyncGenerator, Callable, Literal, Union, cast
|
16 |
+
|
17 |
+
from gradio_client.documentation import document, set_documentation_group
|
18 |
+
|
19 |
+
from typing import List, Optional, Union, Dict, Tuple
|
20 |
+
from tqdm.auto import tqdm
|
21 |
+
from huggingface_hub import snapshot_download
|
22 |
+
import types
|
23 |
+
|
24 |
+
from gradio.components import Button
|
25 |
+
from gradio.events import Dependency, EventListenerMethod
|
26 |
+
|
27 |
+
import types
|
28 |
+
import sys
|
29 |
+
|
30 |
+
from .base_engine import BaseEngine
|
31 |
+
|
32 |
+
# ! Remember to use static cache
|
33 |
+
|
34 |
+
from ..configs import (
|
35 |
+
MODEL_PATH,
|
36 |
+
DEFAULT_CHAT_TEMPLATE,
|
37 |
+
N_CTX,
|
38 |
+
N_GPU_LAYERS,
|
39 |
+
)
|
40 |
+
|
41 |
+
|
42 |
+
|
43 |
+
def encode_tokenize(self, prompt: str, **kwargs):
|
44 |
+
"""Mimic behavior of transformers tokenizer"""
|
45 |
+
prompt_tokens: List[int] = (
|
46 |
+
(
|
47 |
+
self.tokenize(prompt.encode("utf-8"), special=True)
|
48 |
+
if prompt != ""
|
49 |
+
else [self.token_bos()]
|
50 |
+
)
|
51 |
+
if isinstance(prompt, str)
|
52 |
+
else prompt
|
53 |
+
)
|
54 |
+
return prompt_tokens
|
55 |
+
|
56 |
+
|
57 |
+
conversations = [
|
58 |
+
{"role": "system", "content": "You are good."},
|
59 |
+
{"role": "user", "content": "Hello."},
|
60 |
+
{"role": "assistant", "content": "Hi."},
|
61 |
+
]
|
62 |
+
|
63 |
+
|
64 |
+
class LlamaCppEngine(BaseEngine):
|
65 |
+
"""
|
66 |
+
need to create an engine.tokenizer.encode(text) method
|
67 |
+
"""
|
68 |
+
@property
|
69 |
+
def max_position_embeddings(self) -> int:
|
70 |
+
# raise ValueError
|
71 |
+
return self._model.context_params.n_ctx
|
72 |
+
|
73 |
+
def apply_chat_template(self, conversations, add_generation_prompt: bool, add_special_tokens=False, **kwargs) -> str:
|
74 |
+
"""
|
75 |
+
return string convo, add_special_tokens should be added later
|
76 |
+
remember to remove <s> if any,
|
77 |
+
"""
|
78 |
+
from llama_cpp.llama_chat_format import Jinja2ChatFormatter
|
79 |
+
|
80 |
+
formatter = Jinja2ChatFormatter(
|
81 |
+
template=self._model.metadata['tokenizer.chat_template'],
|
82 |
+
# bos_token=self._model._model.token_get_text(self._model.token_bos()),
|
83 |
+
bos_token="",
|
84 |
+
eos_token=self._model._model.token_get_text(self._model.token_eos()),
|
85 |
+
add_generation_prompt=add_generation_prompt,
|
86 |
+
)
|
87 |
+
|
88 |
+
full_prompt = formatter(messages=conversations).prompt
|
89 |
+
# ! it may has bos
|
90 |
+
return full_prompt
|
91 |
+
|
92 |
+
@property
|
93 |
+
def tokenizer(self):
|
94 |
+
return self._model
|
95 |
+
|
96 |
+
def load_model(self):
|
97 |
+
# from transformers import AutoTokenizer, AutoModelForCausalLM
|
98 |
+
|
99 |
+
from llama_cpp import Llama
|
100 |
+
self.model_path = MODEL_PATH
|
101 |
+
self._model = Llama(
|
102 |
+
model_path=self.model_path,
|
103 |
+
n_gpu_layers=N_GPU_LAYERS, # Uncomment to use GPU acceleration
|
104 |
+
# seed=1337, # Uncomment to set a specific seed
|
105 |
+
n_ctx=N_CTX, # Uncomment to increase the context window
|
106 |
+
)
|
107 |
+
self._tokenizer = self._model
|
108 |
+
self._model.encode = types.MethodType(encode_tokenize, self._model)
|
109 |
+
print(f'Load model: {self.model_path=} | {N_GPU_LAYERS=} | {N_CTX=}')
|
110 |
+
|
111 |
+
def generate_yield_string(self, prompt, temperature, max_tokens, stop_strings: Optional[Tuple[str]] = None, **kwargs):
|
112 |
+
stop_strings = list(stop_strings) if stop_strings is not None else []
|
113 |
+
stop_strings = list(set(stop_strings + ["</s>", "<|im_end|>"]))
|
114 |
+
generator = self._model(
|
115 |
+
prompt,
|
116 |
+
max_tokens=max_tokens, # Generate up to 32 tokens, set to None to generate up to the end of the context window
|
117 |
+
temperature=temperature,
|
118 |
+
stop=stop_strings, # Stop generating just before the model would generate a new question
|
119 |
+
stream=True,
|
120 |
+
)
|
121 |
+
response = ""
|
122 |
+
num_tokens = len(self.tokenizer.encode(prompt))
|
123 |
+
for g in generator:
|
124 |
+
response += g['choices'][0]['text']
|
125 |
+
yield response, num_tokens
|
126 |
+
|
127 |
+
if response is not None and len(response) > 0:
|
128 |
+
num_tokens = len(self.tokenizer.encode(prompt + response))
|
129 |
+
yield response, num_tokens
|
130 |
+
|
131 |
+
|
multipurpose_chatbot/engines/llava15_transformers_engine.py
ADDED
@@ -0,0 +1,230 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
|
2 |
+
import os
|
3 |
+
import numpy as np
|
4 |
+
import argparse
|
5 |
+
import torch
|
6 |
+
import gradio as gr
|
7 |
+
from typing import Any, Iterator
|
8 |
+
from typing import Iterator, List, Optional, Tuple
|
9 |
+
import filelock
|
10 |
+
import glob
|
11 |
+
import json
|
12 |
+
import time
|
13 |
+
from gradio.routes import Request
|
14 |
+
from gradio.utils import SyncToAsyncIterator, async_iteration
|
15 |
+
from gradio.helpers import special_args
|
16 |
+
import anyio
|
17 |
+
from typing import AsyncGenerator, Callable, Literal, Union, cast
|
18 |
+
|
19 |
+
from gradio_client.documentation import document, set_documentation_group
|
20 |
+
|
21 |
+
from typing import List, Optional, Union, Dict, Tuple
|
22 |
+
from tqdm.auto import tqdm
|
23 |
+
from huggingface_hub import snapshot_download
|
24 |
+
|
25 |
+
from gradio.components import Button
|
26 |
+
from gradio.events import Dependency, EventListenerMethod
|
27 |
+
from transformers import AutoConfig, AutoTokenizer, PreTrainedTokenizer
|
28 |
+
import types
|
29 |
+
import sys
|
30 |
+
from .base_engine import BaseEngine
|
31 |
+
from .transformers_engine import TransformersEngine, NewGenerationMixin
|
32 |
+
|
33 |
+
from ..configs import (
|
34 |
+
STREAM_CHECK_MULTIPLE,
|
35 |
+
STREAM_YIELD_MULTIPLE,
|
36 |
+
)
|
37 |
+
|
38 |
+
from ..configs import (
|
39 |
+
STREAM_CHECK_MULTIPLE,
|
40 |
+
STREAM_YIELD_MULTIPLE,
|
41 |
+
IMAGE_TOKEN,
|
42 |
+
IMAGE_TOKEN_INTERACTIVE,
|
43 |
+
IMAGE_TOKEN_LENGTH,
|
44 |
+
MAX_PACHES,
|
45 |
+
DTYPE,
|
46 |
+
DEVICE,
|
47 |
+
)
|
48 |
+
|
49 |
+
CODE_PATH = os.environ.get("CODE_PATH", "")
|
50 |
+
MODEL_PATH = os.environ.get("MODEL_PATH", "")
|
51 |
+
|
52 |
+
# IMAGE_TOKEN = "<image"
|
53 |
+
|
54 |
+
# IMAGE_LENGTH = 576
|
55 |
+
# MAX_PACHES = 1
|
56 |
+
|
57 |
+
|
58 |
+
# ! Still working on it....
|
59 |
+
# Should only do with
|
60 |
+
|
61 |
+
"""
|
62 |
+
This is a chat between an inquisitive human and an AI assistant. Assume the role of the AI assistant. Read all the images carefully, and respond to the human's questions with informative, helpful, detailed and polite answers. 这是一个好奇的人类和一个人工智能助手之间的对话。假设你扮演这个AI助手的角色。仔细阅读所有的图像,并对人类的问题做出信息丰富、有帮助、详细的和礼貌的回答。
|
63 |
+
|
64 |
+
### Human: <image_placeholder>
|
65 |
+
Describe the cats and what they are doing in detail.
|
66 |
+
### Assistant:
|
67 |
+
"""
|
68 |
+
|
69 |
+
# prompt = "USER: <image>\nWhat are these?\nASSISTANT:"
|
70 |
+
# image_file = "http://images.cocodataset.org/val2017/000000039769.jpg"
|
71 |
+
|
72 |
+
# conv_llava_llama_2 = Conversation(
|
73 |
+
# system="You are a helpful language and vision assistant. "
|
74 |
+
# "You are able to understand the visual content that the user provides, "
|
75 |
+
# "and assist the user with a variety of tasks using natural language.",
|
76 |
+
# roles=("USER", "ASSISTANT"),
|
77 |
+
# version="llama_v2",
|
78 |
+
# messages=(),
|
79 |
+
# offset=0,
|
80 |
+
# sep_style=SeparatorStyle.LLAMA_2,
|
81 |
+
# sep="<s>",
|
82 |
+
# sep2="</s>",
|
83 |
+
# )
|
84 |
+
|
85 |
+
|
86 |
+
LLAVA_CHAT_TEMPLATE = """"""
|
87 |
+
|
88 |
+
|
89 |
+
# "chat_template": "{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{{ bos_token }}{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '</s>'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}"
|
90 |
+
|
91 |
+
|
92 |
+
if IMAGE_TOKEN != "<image>":
|
93 |
+
print(f'WARNING!!!! {IMAGE_TOKEN=} is not <image>, this can lead to problems')
|
94 |
+
|
95 |
+
|
96 |
+
class Llava15TransformersEngine(TransformersEngine):
|
97 |
+
"""
|
98 |
+
Llava 1.5 hardcoded
|
99 |
+
"""
|
100 |
+
@property
|
101 |
+
def image_token(self):
|
102 |
+
return IMAGE_TOKEN
|
103 |
+
|
104 |
+
@property
|
105 |
+
def max_position_embeddings(self) -> int:
|
106 |
+
return self._model.config.text_config.max_position_embeddings
|
107 |
+
|
108 |
+
@property
|
109 |
+
def tokenizer(self):
|
110 |
+
return self._tokenizer
|
111 |
+
|
112 |
+
@property
|
113 |
+
def processor(self):
|
114 |
+
return self._processor
|
115 |
+
|
116 |
+
|
117 |
+
def apply_chat_template(self, conversations, add_generation_prompt: bool, add_special_tokens=False, **kwargs) -> str:
|
118 |
+
"""
|
119 |
+
return string convo, add_special_tokens should be added later
|
120 |
+
"""
|
121 |
+
prompt = ""
|
122 |
+
for turn in conversations:
|
123 |
+
if turn['role'] == 'system':
|
124 |
+
prompt += turn['content'] + "\n\n"
|
125 |
+
elif turn['role'] == 'user':
|
126 |
+
prompt += f"USER: {turn['content']}\n"
|
127 |
+
elif turn['role'] == 'assistant':
|
128 |
+
prompt += f"ASSISTANT: {turn['content']}\n"
|
129 |
+
if add_generation_prompt:
|
130 |
+
prompt += f"ASSISTANT:"
|
131 |
+
return prompt
|
132 |
+
|
133 |
+
|
134 |
+
def load_model(self):
|
135 |
+
import requests
|
136 |
+
from PIL import Image
|
137 |
+
from transformers import AutoProcessor, LlavaForConditionalGeneration
|
138 |
+
|
139 |
+
self.model_path = model_path = MODEL_PATH
|
140 |
+
self.torch_dtype = torch.bfloat16 if DTYPE == 'bfloat16' else torch.float16
|
141 |
+
self.device_map = DEVICE
|
142 |
+
print(f'Loading model from {model_path} on {self.device_map} with {self.torch_dtype} | LlavaForConditionalGeneration')
|
143 |
+
|
144 |
+
self._processor = AutoProcessor.from_pretrained(self.model_path)
|
145 |
+
self._model = LlavaForConditionalGeneration.from_pretrained(
|
146 |
+
MODEL_PATH,
|
147 |
+
torch_dtype=self.torch_dtype, device_map=self.device_map, trust_remote_code=True
|
148 |
+
).eval()
|
149 |
+
self._model.sample_old = self._model.sample
|
150 |
+
# self._model.sample = types.MethodType(NewGenerationMixin.sample_stream, self._model)
|
151 |
+
self._model._sample = types.MethodType(NewGenerationMixin.sample_stream, self._model)
|
152 |
+
|
153 |
+
self._tokenizer = self._processor.tokenizer
|
154 |
+
print(self._model)
|
155 |
+
print(f"{self.max_position_embeddings=}")
|
156 |
+
|
157 |
+
def get_multimodal_tokens(self, full_prompt, image_paths=None):
|
158 |
+
num_tokens = len(self.tokenizer.encode(full_prompt))
|
159 |
+
for image_path in image_paths:
|
160 |
+
num_tokens += IMAGE_TOKEN_LENGTH * MAX_PACHES
|
161 |
+
return num_tokens
|
162 |
+
|
163 |
+
def generate_yield_string(self, prompt, temperature, max_tokens, stop_strings: Optional[Tuple[str]] = None, **kwargs):
|
164 |
+
from transformers.generation.utils import GenerationConfig
|
165 |
+
from PIL import Image
|
166 |
+
image_paths = kwargs.get("image_paths", None)
|
167 |
+
image_paths = image_paths or []
|
168 |
+
|
169 |
+
images = [Image.open(x) for x in image_paths] if len(image_paths) > 0 else None
|
170 |
+
|
171 |
+
with torch.no_grad():
|
172 |
+
inputs = self.processor(prompt, images, return_tensors='pt')
|
173 |
+
# inputs = inputs.to("cuda", torch.bfloat16)
|
174 |
+
inputs = {k: v.to(self.device_map) for k, v in inputs.items() if v is not None}
|
175 |
+
num_tokens = self.get_multimodal_tokens(prompt, image_paths)
|
176 |
+
# non-streaming generation
|
177 |
+
# output = self._model.generate(
|
178 |
+
# **inputs,
|
179 |
+
# do_sample=True,
|
180 |
+
# temperature=temperature,
|
181 |
+
# max_new_tokens=max_tokens,
|
182 |
+
# pad_token_id=self.processor.tokenizer.pad_token_id,
|
183 |
+
# )
|
184 |
+
# # response = self.processor.tokenizer.decode(output[0][-inputs.input_ids.size(-1):], skip_special_tokens=True)
|
185 |
+
# full_output_text = self.processor.decode(output[0], skip_special_tokens=True)
|
186 |
+
# response = full_output_text.split("<|im_start|>assistant\n")[-1]
|
187 |
+
# num_tokens = self.get_multimodal_tokens(prompt + response, image_paths)
|
188 |
+
# print(prompt)
|
189 |
+
# print(response)
|
190 |
+
# print(num_tokens)
|
191 |
+
# yield response, num_tokens
|
192 |
+
|
193 |
+
# if i % 4 == 0 and i > 1:
|
194 |
+
# message_safety = safety_check(response)
|
195 |
+
# if message_safety is not None:
|
196 |
+
# history = undo_history(history)
|
197 |
+
# yield history, "", None
|
198 |
+
# raise gr.Error(message_safety)
|
199 |
+
|
200 |
+
# # ! streaming
|
201 |
+
generator = self._model.generate(
|
202 |
+
**inputs,
|
203 |
+
do_sample=True,
|
204 |
+
temperature=temperature,
|
205 |
+
max_new_tokens=max_tokens,
|
206 |
+
pad_token_id=self.processor.tokenizer.pad_token_id,
|
207 |
+
)
|
208 |
+
|
209 |
+
out_tokens = []
|
210 |
+
response = None
|
211 |
+
for index, token in enumerate(generator):
|
212 |
+
out_tokens.append(token.item())
|
213 |
+
response = self.processor.tokenizer.decode(out_tokens)
|
214 |
+
|
215 |
+
yield response, num_tokens
|
216 |
+
|
217 |
+
del generator
|
218 |
+
|
219 |
+
if response is not None:
|
220 |
+
|
221 |
+
full_text = prompt + response
|
222 |
+
num_tokens = self.get_multimodal_tokens(full_text, image_paths)
|
223 |
+
yield response, num_tokens
|
224 |
+
|
225 |
+
# raw_image = Image.open(requests.get(image_file, stream=True).raw)
|
226 |
+
# inputs = processor(prompt, raw_image, return_tensors='pt').to(0, torch.float16)
|
227 |
+
|
228 |
+
|
229 |
+
|
230 |
+
|
multipurpose_chatbot/engines/llava_llama_cpp_engine.py
ADDED
@@ -0,0 +1,280 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import numpy as np
|
3 |
+
import argparse
|
4 |
+
import gradio as gr
|
5 |
+
from typing import Any, Iterator
|
6 |
+
from typing import Iterator, List, Optional, Tuple
|
7 |
+
import filelock
|
8 |
+
import glob
|
9 |
+
import json
|
10 |
+
import time
|
11 |
+
from gradio.routes import Request
|
12 |
+
from gradio.utils import SyncToAsyncIterator, async_iteration
|
13 |
+
from gradio.helpers import special_args
|
14 |
+
import anyio
|
15 |
+
from typing import AsyncGenerator, Callable, Literal, Union, cast
|
16 |
+
|
17 |
+
from gradio_client.documentation import document, set_documentation_group
|
18 |
+
|
19 |
+
from typing import List, Optional, Union, Dict, Tuple
|
20 |
+
from tqdm.auto import tqdm
|
21 |
+
from huggingface_hub import snapshot_download
|
22 |
+
import types
|
23 |
+
|
24 |
+
from gradio.components import Button
|
25 |
+
from gradio.events import Dependency, EventListenerMethod
|
26 |
+
|
27 |
+
import types
|
28 |
+
import sys
|
29 |
+
|
30 |
+
from .base_engine import BaseEngine
|
31 |
+
|
32 |
+
# ! Remember to use static cache
|
33 |
+
|
34 |
+
from ..configs import (
|
35 |
+
MODEL_PATH,
|
36 |
+
DEFAULT_CHAT_TEMPLATE,
|
37 |
+
N_CTX,
|
38 |
+
N_GPU_LAYERS,
|
39 |
+
IMAGE_TOKEN,
|
40 |
+
IMAGE_TOKEN_INTERACTIVE,
|
41 |
+
IMAGE_TOKEN_LENGTH,
|
42 |
+
MAX_PACHES,
|
43 |
+
)
|
44 |
+
|
45 |
+
from .llama_cpp_engine import (
|
46 |
+
encode_tokenize,
|
47 |
+
LlamaCppEngine,
|
48 |
+
)
|
49 |
+
|
50 |
+
|
51 |
+
|
52 |
+
# resource: https://llama-cpp-python.readthedocs.io/en/latest/#multi-modal-models
|
53 |
+
|
54 |
+
import base64
|
55 |
+
|
56 |
+
def image_to_base64_data_uri(file_path):
|
57 |
+
with open(file_path, "rb") as img_file:
|
58 |
+
base64_data = base64.b64encode(img_file.read()).decode('utf-8')
|
59 |
+
return f"data:image/png;base64,{base64_data}"
|
60 |
+
|
61 |
+
|
62 |
+
# file_path = 'file_path.png'
|
63 |
+
# data_uri = image_to_base64_data_uri(file_path)
|
64 |
+
|
65 |
+
# data_uri = image_to_base64_data_uri(file_path)
|
66 |
+
|
67 |
+
# messages = [
|
68 |
+
# {"role": "system", "content": "You are an assistant who perfectly describes images."},
|
69 |
+
# {
|
70 |
+
# "role": "user",
|
71 |
+
# "content": [
|
72 |
+
# {"type": "image_url", "image_url": {"url": data_uri }},
|
73 |
+
# {"type" : "text", "text": "Describe this image in detail please."}
|
74 |
+
# ]
|
75 |
+
# }
|
76 |
+
# ]
|
77 |
+
|
78 |
+
|
79 |
+
def llava_15_chat_handler_call(
|
80 |
+
self,
|
81 |
+
*,
|
82 |
+
llama: Any,
|
83 |
+
# messages: List[Any],
|
84 |
+
prompt: Union[str, List[int]],
|
85 |
+
image_data_uris: Optional[List[Any]] = None,
|
86 |
+
image_token: str = None,
|
87 |
+
functions: Optional[List[Any]] = None,
|
88 |
+
function_call: Optional[Any] = None,
|
89 |
+
tools: Optional[List[Any]] = None,
|
90 |
+
tool_choice: Optional[Any] = None,
|
91 |
+
temperature: float = 0.2,
|
92 |
+
top_p: float = 0.95,
|
93 |
+
top_k: int = 40,
|
94 |
+
min_p: float = 0.05,
|
95 |
+
typical_p: float = 1.0,
|
96 |
+
stream: bool = False,
|
97 |
+
stop: Optional[Union[str, List[str]]] = [],
|
98 |
+
response_format: Optional[
|
99 |
+
Any
|
100 |
+
] = None,
|
101 |
+
max_tokens: Optional[int] = None,
|
102 |
+
presence_penalty: float = 0.0,
|
103 |
+
frequency_penalty: float = 0.0,
|
104 |
+
repeat_penalty: float = 1.1,
|
105 |
+
tfs_z: float = 1.0,
|
106 |
+
mirostat_mode: int = 0,
|
107 |
+
mirostat_tau: float = 5.0,
|
108 |
+
mirostat_eta: float = 0.1,
|
109 |
+
model: Optional[str] = None,
|
110 |
+
logits_processor: Optional[Any] = None,
|
111 |
+
grammar: Optional[Any] = None,
|
112 |
+
**kwargs, # type: ignore
|
113 |
+
):
|
114 |
+
from llama_cpp.llama_chat_format import (
|
115 |
+
ctypes,
|
116 |
+
suppress_stdout_stderr,
|
117 |
+
)
|
118 |
+
assert (
|
119 |
+
llama.context_params.logits_all is True
|
120 |
+
) # BUG: logits_all=True is required for llava
|
121 |
+
assert self.clip_ctx is not None
|
122 |
+
# ! split prompt into different parts
|
123 |
+
assert image_token is not None
|
124 |
+
prompt_parts = prompt.split(image_token)
|
125 |
+
# assert len(prompt_parts)
|
126 |
+
assert len(prompt_parts) == len(image_data_uris) + 1, f'invalid {len(prompt_parts)=} != {len(image_data_uris)=}'
|
127 |
+
llama.reset()
|
128 |
+
prefix = prompt_parts[0]
|
129 |
+
remaining_texts = prompt_parts[1:]
|
130 |
+
llama.reset()
|
131 |
+
llama.eval(llama.tokenize(prefix.encode("utf8"), add_bos=True))
|
132 |
+
for index, (image_uri, prompt_p) in enumerate(zip(image_data_uris, remaining_texts)):
|
133 |
+
image_bytes = self.load_image(image_uri)
|
134 |
+
import array
|
135 |
+
data_array = array.array("B", image_bytes)
|
136 |
+
c_ubyte_ptr = (
|
137 |
+
ctypes.c_ubyte * len(data_array)
|
138 |
+
).from_buffer(data_array)
|
139 |
+
with suppress_stdout_stderr(disable=self.verbose):
|
140 |
+
embed = (
|
141 |
+
self._llava_cpp.llava_image_embed_make_with_bytes(
|
142 |
+
self.clip_ctx,
|
143 |
+
llama.context_params.n_threads,
|
144 |
+
c_ubyte_ptr,
|
145 |
+
len(image_bytes),
|
146 |
+
)
|
147 |
+
)
|
148 |
+
try:
|
149 |
+
n_past = ctypes.c_int(llama.n_tokens)
|
150 |
+
n_past_p = ctypes.pointer(n_past)
|
151 |
+
with suppress_stdout_stderr(disable=self.verbose):
|
152 |
+
self._llava_cpp.llava_eval_image_embed(
|
153 |
+
llama.ctx,
|
154 |
+
embed,
|
155 |
+
llama.n_batch,
|
156 |
+
n_past_p,
|
157 |
+
)
|
158 |
+
assert llama.n_ctx() >= n_past.value
|
159 |
+
llama.n_tokens = n_past.value
|
160 |
+
finally:
|
161 |
+
with suppress_stdout_stderr(disable=self.verbose):
|
162 |
+
self._llava_cpp.llava_image_embed_free(embed)
|
163 |
+
|
164 |
+
llama.eval(llama.tokenize(prompt_p.encode("utf8"), add_bos=False))
|
165 |
+
assert llama.n_ctx() >= llama.n_tokens
|
166 |
+
|
167 |
+
prompt = llama.input_ids[: llama.n_tokens].tolist()
|
168 |
+
# from llava-1.5
|
169 |
+
return llama.create_completion(
|
170 |
+
prompt=prompt,
|
171 |
+
temperature=temperature,
|
172 |
+
top_p=top_p,
|
173 |
+
top_k=top_k,
|
174 |
+
min_p=min_p,
|
175 |
+
typical_p=typical_p,
|
176 |
+
stream=stream,
|
177 |
+
stop=stop,
|
178 |
+
max_tokens=max_tokens,
|
179 |
+
presence_penalty=presence_penalty,
|
180 |
+
frequency_penalty=frequency_penalty,
|
181 |
+
repeat_penalty=repeat_penalty,
|
182 |
+
tfs_z=tfs_z,
|
183 |
+
mirostat_mode=mirostat_mode,
|
184 |
+
mirostat_tau=mirostat_tau,
|
185 |
+
mirostat_eta=mirostat_eta,
|
186 |
+
model=model,
|
187 |
+
logits_processor=logits_processor,
|
188 |
+
grammar=grammar,
|
189 |
+
)
|
190 |
+
|
191 |
+
|
192 |
+
|
193 |
+
class LlavaLlamaCppEngine(LlamaCppEngine):
|
194 |
+
"""
|
195 |
+
Still in development, expect BUGS
|
196 |
+
|
197 |
+
ERROR: could not know why
|
198 |
+
objc[61055]: Class GGMLMetalClass is implemented in both miniconda3/envs/native/lib/python3.12/site-packages/llama_cpp/libllama.dylib (0x12cb40290) and miniconda3/envs/native/lib/python3.12/site-packages/llama_cpp/libllava.dylib (0x12d9c8290). One of the two will be used. Which one is undefined.
|
199 |
+
|
200 |
+
"""
|
201 |
+
@property
|
202 |
+
def image_token(self):
|
203 |
+
return IMAGE_TOKEN
|
204 |
+
|
205 |
+
def get_multimodal_tokens(self, full_prompt, image_paths=None):
|
206 |
+
num_tokens = len(self.tokenizer.encode(full_prompt))
|
207 |
+
for image_path in image_paths:
|
208 |
+
num_tokens += IMAGE_TOKEN_LENGTH * MAX_PACHES
|
209 |
+
return num_tokens
|
210 |
+
|
211 |
+
def load_model(self):
|
212 |
+
# from transformers import AutoTokenizer, AutoModelForCausalLM
|
213 |
+
from llama_cpp import Llama
|
214 |
+
from llama_cpp.llama_chat_format import Llava15ChatHandler
|
215 |
+
model_dir = os.path.dirname(MODEL_PATH)
|
216 |
+
self.chat_handler = Llava15ChatHandler(clip_model_path=os.path.join(model_dir, "mmproj.bin"))
|
217 |
+
|
218 |
+
self.chat_handler.__call__ = types.MethodType(llava_15_chat_handler_call, self.chat_handler)
|
219 |
+
|
220 |
+
self.model_path = MODEL_PATH
|
221 |
+
self._model = Llama(
|
222 |
+
model_path=self.model_path,
|
223 |
+
n_gpu_layers=N_GPU_LAYERS, # Uncomment to use GPU acceleration
|
224 |
+
# seed=1337, # Uncomment to set a specific seed
|
225 |
+
chat_handler=self.chat_handler,
|
226 |
+
n_ctx=N_CTX, # Uncomment to increase the context window
|
227 |
+
logits_all=True, # needed to make llava work
|
228 |
+
)
|
229 |
+
self._tokenizer = self._model
|
230 |
+
self._model.encode = types.MethodType(encode_tokenize, self._model)
|
231 |
+
print(f'Load model: {self.model_path=} | {N_GPU_LAYERS=} | {N_CTX=}')
|
232 |
+
|
233 |
+
def generate_yield_string(self, prompt, temperature, max_tokens, stop_strings: Optional[Tuple[str]] = None, **kwargs):
|
234 |
+
image_paths = kwargs.get("image_paths", [])
|
235 |
+
|
236 |
+
image_data_uris = [
|
237 |
+
image_to_base64_data_uri(ip)
|
238 |
+
for ip in image_paths
|
239 |
+
]
|
240 |
+
|
241 |
+
stop_strings = list(stop_strings) if stop_strings is not None else []
|
242 |
+
stop_strings = list(set(stop_strings + ["</s>", "<|im_end|>"]))
|
243 |
+
# generator = self._model(
|
244 |
+
generator = self.chat_handler(
|
245 |
+
prompt=prompt,
|
246 |
+
image_data_uris=image_data_uris,
|
247 |
+
image_token=self.image_token,
|
248 |
+
max_tokens=max_tokens, # Generate up to 32 tokens, set to None to generate up to the end of the context window
|
249 |
+
temperature=temperature,
|
250 |
+
stop=stop_strings, # Stop generating just before the model would generate a new question
|
251 |
+
stream=True,
|
252 |
+
)
|
253 |
+
response = ""
|
254 |
+
num_tokens = len(self.tokenizer.encode(prompt))
|
255 |
+
for g in generator:
|
256 |
+
response += g['choices'][0]['text']
|
257 |
+
yield response, num_tokens
|
258 |
+
|
259 |
+
if response is not None and len(response) > 0:
|
260 |
+
num_tokens = len(self.tokenizer.encode(prompt + response))
|
261 |
+
yield response, num_tokens
|
262 |
+
|
263 |
+
|
264 |
+
"""
|
265 |
+
|
266 |
+
export MODEL_PATH
|
267 |
+
BACKEND=llama_cpp
|
268 |
+
MODEL_PATH=/Users/nguyenxuanphi/Desktop/projects/cache/seallms/SeaLLMs/SeaLLM-7B-v2-gguf/seallm-v2.chatml.Q4_K_M.gguf
|
269 |
+
N_CTX=4096
|
270 |
+
python app.py
|
271 |
+
|
272 |
+
|
273 |
+
export BACKEND=llava_llama_cpp
|
274 |
+
export MODEL_PATH=/Users/nguyenxuanphi/Desktop/projects/cache/llava/llava-1.5/ggml-model-q4_k.gguf
|
275 |
+
export N_CTX=4096
|
276 |
+
export IMAGE_TOKEN="<image>"
|
277 |
+
python app.py
|
278 |
+
|
279 |
+
|
280 |
+
"""
|
multipurpose_chatbot/engines/mlx_engine.py
ADDED
@@ -0,0 +1,202 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import numpy as np
|
3 |
+
import mlx.core as mx
|
4 |
+
import mlx.nn as nn
|
5 |
+
from huggingface_hub import snapshot_download
|
6 |
+
from transformers import AutoConfig, AutoTokenizer, PreTrainedTokenizer
|
7 |
+
from typing import Any, Callable, Dict, Generator, List, Optional, Tuple, Union
|
8 |
+
import time
|
9 |
+
from mlx_lm import load, generate
|
10 |
+
from mlx_lm.utils import generate_step
|
11 |
+
|
12 |
+
from .base_engine import BaseEngine
|
13 |
+
|
14 |
+
from ..configs import (
|
15 |
+
MODEL_PATH,
|
16 |
+
)
|
17 |
+
|
18 |
+
def generate_string(
|
19 |
+
model: nn.Module,
|
20 |
+
tokenizer: PreTrainedTokenizer,
|
21 |
+
prompt: str,
|
22 |
+
temp: float = 0.0,
|
23 |
+
max_tokens: int = 100,
|
24 |
+
verbose: bool = False,
|
25 |
+
formatter: Callable = None,
|
26 |
+
repetition_penalty: Optional[float] = None,
|
27 |
+
repetition_context_size: Optional[int] = None,
|
28 |
+
stop_strings: Optional[Tuple[str]] = None
|
29 |
+
):
|
30 |
+
prompt_tokens = mx.array(tokenizer.encode(prompt))
|
31 |
+
stop_strings = stop_strings if stop_strings is None or isinstance(stop_strings, tuple) else tuple(stop_strings)
|
32 |
+
assert stop_strings is None or isinstance(stop_strings, tuple), f'invalid {stop_strings}'
|
33 |
+
|
34 |
+
tic = time.perf_counter()
|
35 |
+
tokens = []
|
36 |
+
skip = 0
|
37 |
+
REPLACEMENT_CHAR = "\ufffd"
|
38 |
+
|
39 |
+
for (token, prob), n in zip(
|
40 |
+
generate_step(
|
41 |
+
prompt_tokens,
|
42 |
+
model,
|
43 |
+
temp,
|
44 |
+
repetition_penalty,
|
45 |
+
repetition_context_size,
|
46 |
+
),
|
47 |
+
range(max_tokens),
|
48 |
+
):
|
49 |
+
if token == tokenizer.eos_token_id:
|
50 |
+
break
|
51 |
+
if n == 0:
|
52 |
+
prompt_time = time.perf_counter() - tic
|
53 |
+
tic = time.perf_counter()
|
54 |
+
tokens.append(token.item())
|
55 |
+
if stop_strings is not None:
|
56 |
+
token_string = tokenizer.decode(tokens).replace(REPLACEMENT_CHAR, "")
|
57 |
+
if token_string.strip().endswith(stop_strings):
|
58 |
+
break
|
59 |
+
token_string = tokenizer.decode(tokens).replace(REPLACEMENT_CHAR, "")
|
60 |
+
return token_string
|
61 |
+
|
62 |
+
|
63 |
+
|
64 |
+
def generate_yield_string(
|
65 |
+
model: nn.Module,
|
66 |
+
tokenizer: PreTrainedTokenizer,
|
67 |
+
prompt: str,
|
68 |
+
temp: float = 0.0,
|
69 |
+
max_tokens: int = 100,
|
70 |
+
verbose: bool = False,
|
71 |
+
formatter: Callable = None,
|
72 |
+
repetition_penalty: Optional[float] = None,
|
73 |
+
repetition_context_size: Optional[int] = None,
|
74 |
+
stop_strings: Optional[Tuple[str]] = None
|
75 |
+
):
|
76 |
+
"""
|
77 |
+
Generate text from the model.
|
78 |
+
Args:
|
79 |
+
model (nn.Module): The language model.
|
80 |
+
tokenizer (PreTrainedTokenizer): The tokenizer.
|
81 |
+
prompt (str): The string prompt.
|
82 |
+
temp (float): The temperature for sampling (default 0).
|
83 |
+
max_tokens (int): The maximum number of tokens (default 100).
|
84 |
+
verbose (bool): If ``True``, print tokens and timing information
|
85 |
+
(default ``False``).
|
86 |
+
formatter (Optional[Callable]): A function which takes a token and a
|
87 |
+
probability and displays it.
|
88 |
+
repetition_penalty (float, optional): The penalty factor for repeating tokens.
|
89 |
+
repetition_context_size (int, optional): The number of tokens to consider for repetition penalty.
|
90 |
+
"""
|
91 |
+
if verbose:
|
92 |
+
print("=" * 10)
|
93 |
+
print("Prompt:", prompt)
|
94 |
+
stop_strings = stop_strings if stop_strings is None or isinstance(stop_strings, tuple) else tuple(stop_strings)
|
95 |
+
assert stop_strings is None or isinstance(stop_strings, tuple), f'invalid {stop_strings}'
|
96 |
+
prompt_tokens = mx.array(tokenizer.encode(prompt))
|
97 |
+
tic = time.perf_counter()
|
98 |
+
tokens = []
|
99 |
+
skip = 0
|
100 |
+
REPLACEMENT_CHAR = "\ufffd"
|
101 |
+
for (token, prob), n in zip(
|
102 |
+
generate_step(
|
103 |
+
prompt_tokens,
|
104 |
+
model,
|
105 |
+
temp,
|
106 |
+
repetition_penalty,
|
107 |
+
repetition_context_size,
|
108 |
+
),
|
109 |
+
range(max_tokens),
|
110 |
+
):
|
111 |
+
if token == tokenizer.eos_token_id:
|
112 |
+
break
|
113 |
+
# if n == 0:
|
114 |
+
# prompt_time = time.perf_counter() - tic
|
115 |
+
# tic = time.perf_counter()
|
116 |
+
tokens.append(token.item())
|
117 |
+
# if verbose:
|
118 |
+
# s = tokenizer.decode(tokens)
|
119 |
+
# if formatter:
|
120 |
+
# formatter(s[skip:], prob.item())
|
121 |
+
# skip = len(s)
|
122 |
+
# elif REPLACEMENT_CHAR not in s:
|
123 |
+
# print(s[skip:], end="", flush=True)
|
124 |
+
# skip = len(s)
|
125 |
+
token_string = tokenizer.decode(tokens).replace(REPLACEMENT_CHAR, "")
|
126 |
+
yield token_string
|
127 |
+
if stop_strings is not None and token_string.strip().endswith(stop_strings):
|
128 |
+
break
|
129 |
+
|
130 |
+
# token_count = len(tokens)
|
131 |
+
# token_string = tokenizer.decode(tokens).replace(REPLACEMENT_CHAR, "")
|
132 |
+
|
133 |
+
# if verbose:
|
134 |
+
# print(token_string[skip:], flush=True)
|
135 |
+
# gen_time = time.perf_counter() - tic
|
136 |
+
# print("=" * 10)
|
137 |
+
# if token_count == 0:
|
138 |
+
# print("No tokens generated for this prompt")
|
139 |
+
# return
|
140 |
+
# prompt_tps = prompt_tokens.size / prompt_time
|
141 |
+
# gen_tps = (token_count - 1) / gen_time
|
142 |
+
# print(f"Prompt: {prompt_tps:.3f} tokens-per-sec")
|
143 |
+
# print(f"Generation: {gen_tps:.3f} tokens-per-sec")
|
144 |
+
|
145 |
+
# return token_string
|
146 |
+
|
147 |
+
|
148 |
+
class MlxEngine(BaseEngine):
|
149 |
+
|
150 |
+
def __init__(self, **kwargs) -> None:
|
151 |
+
super().__init__(**kwargs)
|
152 |
+
self._model = None
|
153 |
+
self._tokenizer = None
|
154 |
+
|
155 |
+
@property
|
156 |
+
def tokenizer(self) -> PreTrainedTokenizer:
|
157 |
+
return self._tokenizer
|
158 |
+
|
159 |
+
def load_model(self, ):
|
160 |
+
model_path = MODEL_PATH
|
161 |
+
self._model, self._tokenizer = load(model_path)
|
162 |
+
self.model_path = model_path
|
163 |
+
print(f'Load MLX model from {model_path}')
|
164 |
+
|
165 |
+
|
166 |
+
def generate_yield_string(self, prompt, temperature, max_tokens, stop_strings: Optional[Tuple[str]] = None, **kwargs):
|
167 |
+
num_tokens = len(self.tokenizer.encode(prompt))
|
168 |
+
response = None
|
169 |
+
for response in generate_yield_string(
|
170 |
+
self._model, self._tokenizer,
|
171 |
+
prompt, temp=temperature, max_tokens=max_tokens,
|
172 |
+
repetition_penalty=kwargs.get("repetition_penalty", None),
|
173 |
+
stop_strings=stop_strings,
|
174 |
+
):
|
175 |
+
yield response, num_tokens
|
176 |
+
if response is not None:
|
177 |
+
full_text = prompt + response
|
178 |
+
num_tokens = len(self.tokenizer.encode(full_text))
|
179 |
+
yield response, num_tokens
|
180 |
+
|
181 |
+
def batch_generate(self, prompts, temperature, max_tokens, stop_strings: Optional[Tuple[str]] = None, **kwargs):
|
182 |
+
"""
|
183 |
+
! MLX does not support
|
184 |
+
"""
|
185 |
+
responses = [
|
186 |
+
generate_string(
|
187 |
+
self._model, self._tokenizer,
|
188 |
+
s, temp=temperature, max_tokens=max_tokens,
|
189 |
+
repetition_penalty=kwargs.get("repetition_penalty", None),
|
190 |
+
stop_strings=stop_strings,
|
191 |
+
)
|
192 |
+
for s in prompts
|
193 |
+
]
|
194 |
+
return responses
|
195 |
+
|
196 |
+
|
197 |
+
|
198 |
+
|
199 |
+
|
200 |
+
|
201 |
+
|
202 |
+
|
multipurpose_chatbot/engines/transformers_engine.py
ADDED
@@ -0,0 +1,452 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
|
2 |
+
import os
|
3 |
+
import numpy as np
|
4 |
+
import argparse
|
5 |
+
import torch
|
6 |
+
import gradio as gr
|
7 |
+
from typing import Any, Iterator
|
8 |
+
from typing import Iterator, List, Optional, Tuple
|
9 |
+
import filelock
|
10 |
+
import glob
|
11 |
+
import json
|
12 |
+
import time
|
13 |
+
from gradio.routes import Request
|
14 |
+
from gradio.utils import SyncToAsyncIterator, async_iteration
|
15 |
+
from gradio.helpers import special_args
|
16 |
+
import anyio
|
17 |
+
from typing import AsyncGenerator, Callable, Literal, Union, cast
|
18 |
+
|
19 |
+
from gradio_client.documentation import document, set_documentation_group
|
20 |
+
|
21 |
+
from typing import List, Optional, Union, Dict, Tuple
|
22 |
+
from tqdm.auto import tqdm
|
23 |
+
from huggingface_hub import snapshot_download
|
24 |
+
import types
|
25 |
+
|
26 |
+
from gradio.components import Button
|
27 |
+
from gradio.events import Dependency, EventListenerMethod
|
28 |
+
|
29 |
+
from .base_engine import BaseEngine
|
30 |
+
|
31 |
+
# ! Remember to use static cache
|
32 |
+
|
33 |
+
from transformers import (
|
34 |
+
GenerationConfig,
|
35 |
+
GenerationMixin,
|
36 |
+
LogitsProcessorList,
|
37 |
+
StoppingCriteriaList,
|
38 |
+
DisjunctiveConstraint,
|
39 |
+
BeamSearchScorer,
|
40 |
+
PhrasalConstraint,
|
41 |
+
ConstrainedBeamSearchScorer,
|
42 |
+
PreTrainedModel,
|
43 |
+
)
|
44 |
+
import numpy as np
|
45 |
+
import random
|
46 |
+
import warnings
|
47 |
+
import inspect
|
48 |
+
from transformers.generation.utils import GenerateOutput, SampleOutput, logger
|
49 |
+
import torch
|
50 |
+
from typing import Callable, List, Optional, Union
|
51 |
+
from torch import nn
|
52 |
+
import torch.distributed as dist
|
53 |
+
import copy
|
54 |
+
|
55 |
+
from ..configs import (
|
56 |
+
MODEL_PATH,
|
57 |
+
DTYPE,
|
58 |
+
DEVICE,
|
59 |
+
)
|
60 |
+
|
61 |
+
|
62 |
+
def setup_seed(seed):
|
63 |
+
if seed == -1:
|
64 |
+
return
|
65 |
+
torch.manual_seed(seed)
|
66 |
+
if torch.cuda.is_available():
|
67 |
+
torch.cuda.manual_seed_all(seed)
|
68 |
+
np.random.seed(seed)
|
69 |
+
random.seed(seed)
|
70 |
+
torch.backends.cudnn.deterministic = True
|
71 |
+
|
72 |
+
|
73 |
+
class NewGenerationMixin(GenerationMixin):
|
74 |
+
"""
|
75 |
+
Allow generator sampling
|
76 |
+
|
77 |
+
"""
|
78 |
+
|
79 |
+
# ! Copy from transformers.generation.utils -> GenerationMixin
|
80 |
+
# Change sample function to sample_stream
|
81 |
+
@torch.no_grad()
|
82 |
+
def sample_stream(
|
83 |
+
self,
|
84 |
+
input_ids: torch.LongTensor,
|
85 |
+
logits_processor: Optional[LogitsProcessorList] = None,
|
86 |
+
stopping_criteria: Optional[StoppingCriteriaList] = None,
|
87 |
+
logits_warper: Optional[LogitsProcessorList] = None,
|
88 |
+
max_length: Optional[int] = None,
|
89 |
+
pad_token_id: Optional[int] = None,
|
90 |
+
eos_token_id: Optional[Union[int, List[int]]] = None,
|
91 |
+
output_attentions: Optional[bool] = None,
|
92 |
+
output_hidden_states: Optional[bool] = None,
|
93 |
+
output_scores: Optional[bool] = None,
|
94 |
+
output_logits: Optional[bool] = None,
|
95 |
+
return_dict_in_generate: Optional[bool] = None,
|
96 |
+
synced_gpus: bool = False,
|
97 |
+
streamer: Optional["BaseStreamer"] = None,
|
98 |
+
**model_kwargs,
|
99 |
+
):
|
100 |
+
r"""
|
101 |
+
Generates sequences of token ids for models with a language modeling head using **multinomial sampling** and
|
102 |
+
can be used for text-decoder, text-to-text, speech-to-text, and vision-to-text models.
|
103 |
+
|
104 |
+
<Tip warning={true}>
|
105 |
+
|
106 |
+
In most cases, you do not need to call [`~generation.GenerationMixin.sample`] directly. Use generate() instead.
|
107 |
+
For an overview of generation strategies and code examples, check the [following
|
108 |
+
guide](../generation_strategies).
|
109 |
+
|
110 |
+
</Tip>
|
111 |
+
|
112 |
+
Parameters:
|
113 |
+
input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
|
114 |
+
The sequence used as a prompt for the generation.
|
115 |
+
logits_processor (`LogitsProcessorList`, *optional*):
|
116 |
+
An instance of [`LogitsProcessorList`]. List of instances of class derived from [`LogitsProcessor`]
|
117 |
+
used to modify the prediction scores of the language modeling head applied at each generation step.
|
118 |
+
stopping_criteria (`StoppingCriteriaList`, *optional*):
|
119 |
+
An instance of [`StoppingCriteriaList`]. List of instances of class derived from [`StoppingCriteria`]
|
120 |
+
used to tell if the generation loop should stop.
|
121 |
+
logits_warper (`LogitsProcessorList`, *optional*):
|
122 |
+
An instance of [`LogitsProcessorList`]. List of instances of class derived from [`LogitsWarper`] used
|
123 |
+
to warp the prediction score distribution of the language modeling head applied before multinomial
|
124 |
+
sampling at each generation step.
|
125 |
+
max_length (`int`, *optional*, defaults to 20):
|
126 |
+
**DEPRECATED**. Use `logits_processor` or `stopping_criteria` directly to cap the number of generated
|
127 |
+
tokens. The maximum length of the sequence to be generated.
|
128 |
+
pad_token_id (`int`, *optional*):
|
129 |
+
The id of the *padding* token.
|
130 |
+
eos_token_id (`Union[int, List[int]]`, *optional*):
|
131 |
+
The id of the *end-of-sequence* token. Optionally, use a list to set multiple *end-of-sequence* tokens.
|
132 |
+
output_attentions (`bool`, *optional*, defaults to `False`):
|
133 |
+
Whether or not to return the attentions tensors of all attention layers. See `attentions` under
|
134 |
+
returned tensors for more details.
|
135 |
+
output_hidden_states (`bool`, *optional*, defaults to `False`):
|
136 |
+
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
|
137 |
+
for more details.
|
138 |
+
output_scores (`bool`, *optional*, defaults to `False`):
|
139 |
+
Whether or not to return the prediction scores. See `scores` under returned tensors for more details.
|
140 |
+
output_logits (`bool`, *optional*, defaults to `False`):
|
141 |
+
Whether or not to return the raw prediction logit scores. See `logits` under returned tensors for
|
142 |
+
more details.
|
143 |
+
return_dict_in_generate (`bool`, *optional*, defaults to `False`):
|
144 |
+
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
|
145 |
+
synced_gpus (`bool`, *optional*, defaults to `False`):
|
146 |
+
Whether to continue running the while loop until max_length (needed for ZeRO stage 3)
|
147 |
+
streamer (`BaseStreamer`, *optional*):
|
148 |
+
Streamer object that will be used to stream the generated sequences. Generated tokens are passed
|
149 |
+
through `streamer.put(token_ids)` and the streamer is responsible for any further processing.
|
150 |
+
model_kwargs:
|
151 |
+
Additional model specific kwargs will be forwarded to the `forward` function of the model. If model is
|
152 |
+
an encoder-decoder model the kwargs should include `encoder_outputs`.
|
153 |
+
|
154 |
+
Return:
|
155 |
+
[`~generation.GenerateDecoderOnlyOutput`], [`~generation.GenerateEncoderDecoderOutput`] or `torch.LongTensor`:
|
156 |
+
A `torch.LongTensor` containing the generated tokens (default behaviour) or a
|
157 |
+
[`~generation.GenerateDecoderOnlyOutput`] if `model.config.is_encoder_decoder=False` and
|
158 |
+
`return_dict_in_generate=True` or a [`~generation.GenerateEncoderDecoderOutput`] if
|
159 |
+
`model.config.is_encoder_decoder=True`.
|
160 |
+
|
161 |
+
Examples:
|
162 |
+
|
163 |
+
```python
|
164 |
+
>>> from transformers import (
|
165 |
+
... AutoTokenizer,
|
166 |
+
... AutoModelForCausalLM,
|
167 |
+
... LogitsProcessorList,
|
168 |
+
... MinLengthLogitsProcessor,
|
169 |
+
... TopKLogitsWarper,
|
170 |
+
... TemperatureLogitsWarper,
|
171 |
+
... StoppingCriteriaList,
|
172 |
+
... MaxLengthCriteria,
|
173 |
+
... )
|
174 |
+
>>> import torch
|
175 |
+
|
176 |
+
>>> tokenizer = AutoTokenizer.from_pretrained("openai-community/gpt2")
|
177 |
+
>>> model = AutoModelForCausalLM.from_pretrained("openai-community/gpt2")
|
178 |
+
|
179 |
+
>>> # set pad_token_id to eos_token_id because GPT2 does not have a EOS token
|
180 |
+
>>> model.config.pad_token_id = model.config.eos_token_id
|
181 |
+
>>> model.generation_config.pad_token_id = model.config.eos_token_id
|
182 |
+
|
183 |
+
>>> input_prompt = "Today is a beautiful day, and"
|
184 |
+
>>> input_ids = tokenizer(input_prompt, return_tensors="pt").input_ids
|
185 |
+
|
186 |
+
>>> # instantiate logits processors
|
187 |
+
>>> logits_processor = LogitsProcessorList(
|
188 |
+
... [
|
189 |
+
... MinLengthLogitsProcessor(15, eos_token_id=model.generation_config.eos_token_id),
|
190 |
+
... ]
|
191 |
+
... )
|
192 |
+
>>> # instantiate logits processors
|
193 |
+
>>> logits_warper = LogitsProcessorList(
|
194 |
+
... [
|
195 |
+
... TopKLogitsWarper(50),
|
196 |
+
... TemperatureLogitsWarper(0.7),
|
197 |
+
... ]
|
198 |
+
... )
|
199 |
+
|
200 |
+
>>> stopping_criteria = StoppingCriteriaList([MaxLengthCriteria(max_length=20)])
|
201 |
+
|
202 |
+
>>> torch.manual_seed(0) # doctest: +IGNORE_RESULT
|
203 |
+
>>> outputs = model.sample(
|
204 |
+
... input_ids,
|
205 |
+
... logits_processor=logits_processor,
|
206 |
+
... logits_warper=logits_warper,
|
207 |
+
... stopping_criteria=stopping_criteria,
|
208 |
+
... )
|
209 |
+
|
210 |
+
>>> tokenizer.batch_decode(outputs, skip_special_tokens=True)
|
211 |
+
['Today is a beautiful day, and we must do everything possible to make it a day of celebration.']
|
212 |
+
```"""
|
213 |
+
# init values
|
214 |
+
from transformers.generation.utils import (
|
215 |
+
validate_stopping_criteria, GenerateEncoderDecoderOutput, GenerateDecoderOnlyOutput
|
216 |
+
)
|
217 |
+
logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList()
|
218 |
+
stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList()
|
219 |
+
if max_length is not None:
|
220 |
+
warnings.warn(
|
221 |
+
"`max_length` is deprecated in this function, use"
|
222 |
+
" `stopping_criteria=StoppingCriteriaList([MaxLengthCriteria(max_length=max_length)])` instead.",
|
223 |
+
UserWarning,
|
224 |
+
)
|
225 |
+
stopping_criteria = validate_stopping_criteria(stopping_criteria, max_length)
|
226 |
+
logits_warper = logits_warper if logits_warper is not None else LogitsProcessorList()
|
227 |
+
pad_token_id = pad_token_id if pad_token_id is not None else self.generation_config.pad_token_id
|
228 |
+
eos_token_id = eos_token_id if eos_token_id is not None else self.generation_config.eos_token_id
|
229 |
+
if isinstance(eos_token_id, int):
|
230 |
+
eos_token_id = [eos_token_id]
|
231 |
+
eos_token_id_tensor = torch.tensor(eos_token_id).to(input_ids.device) if eos_token_id is not None else None
|
232 |
+
output_scores = output_scores if output_scores is not None else self.generation_config.output_scores
|
233 |
+
output_logits = output_logits if output_logits is not None else self.generation_config.output_logits
|
234 |
+
output_attentions = (
|
235 |
+
output_attentions if output_attentions is not None else self.generation_config.output_attentions
|
236 |
+
)
|
237 |
+
output_hidden_states = (
|
238 |
+
output_hidden_states if output_hidden_states is not None else self.generation_config.output_hidden_states
|
239 |
+
)
|
240 |
+
return_dict_in_generate = (
|
241 |
+
return_dict_in_generate
|
242 |
+
if return_dict_in_generate is not None
|
243 |
+
else self.generation_config.return_dict_in_generate
|
244 |
+
)
|
245 |
+
|
246 |
+
# init attention / hidden states / scores tuples
|
247 |
+
scores = () if (return_dict_in_generate and output_scores) else None
|
248 |
+
raw_logits = () if (return_dict_in_generate and output_logits) else None
|
249 |
+
decoder_attentions = () if (return_dict_in_generate and output_attentions) else None
|
250 |
+
cross_attentions = () if (return_dict_in_generate and output_attentions) else None
|
251 |
+
decoder_hidden_states = () if (return_dict_in_generate and output_hidden_states) else None
|
252 |
+
|
253 |
+
# if model is an encoder-decoder, retrieve encoder attention weights and hidden states
|
254 |
+
if return_dict_in_generate and self.config.is_encoder_decoder:
|
255 |
+
encoder_attentions = model_kwargs["encoder_outputs"].get("attentions") if output_attentions else None
|
256 |
+
encoder_hidden_states = (
|
257 |
+
model_kwargs["encoder_outputs"].get("hidden_states") if output_hidden_states else None
|
258 |
+
)
|
259 |
+
# keep track of which sequences are already finished
|
260 |
+
unfinished_sequences = torch.ones(input_ids.shape[0], dtype=torch.long, device=input_ids.device)
|
261 |
+
|
262 |
+
this_peer_finished = False # used by synced_gpus only
|
263 |
+
# auto-regressive generation
|
264 |
+
while True:
|
265 |
+
if synced_gpus:
|
266 |
+
# Under synced_gpus the `forward` call must continue until all gpus complete their sequence.
|
267 |
+
# The following logic allows an early break if all peers finished generating their sequence
|
268 |
+
this_peer_finished_flag = torch.tensor(0.0 if this_peer_finished else 1.0).to(input_ids.device)
|
269 |
+
# send 0.0 if we finished, 1.0 otherwise
|
270 |
+
dist.all_reduce(this_peer_finished_flag, op=dist.ReduceOp.SUM)
|
271 |
+
# did all peers finish? the reduced sum will be 0.0 then
|
272 |
+
if this_peer_finished_flag.item() == 0.0:
|
273 |
+
break
|
274 |
+
|
275 |
+
# prepare model inputs
|
276 |
+
model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
|
277 |
+
|
278 |
+
# forward pass to get next token
|
279 |
+
outputs = self(
|
280 |
+
**model_inputs,
|
281 |
+
return_dict=True,
|
282 |
+
output_attentions=output_attentions,
|
283 |
+
output_hidden_states=output_hidden_states,
|
284 |
+
)
|
285 |
+
|
286 |
+
if synced_gpus and this_peer_finished:
|
287 |
+
continue # don't waste resources running the code we don't need
|
288 |
+
|
289 |
+
next_token_logits = outputs.logits[:, -1, :]
|
290 |
+
|
291 |
+
# pre-process distribution
|
292 |
+
next_token_scores = logits_processor(input_ids, next_token_logits)
|
293 |
+
next_token_scores = logits_warper(input_ids, next_token_scores)
|
294 |
+
|
295 |
+
# Store scores, attentions and hidden_states when required
|
296 |
+
if return_dict_in_generate:
|
297 |
+
if output_scores:
|
298 |
+
scores += (next_token_scores,)
|
299 |
+
if output_logits:
|
300 |
+
raw_logits += (next_token_logits,)
|
301 |
+
if output_attentions:
|
302 |
+
decoder_attentions += (
|
303 |
+
(outputs.decoder_attentions,) if self.config.is_encoder_decoder else (outputs.attentions,)
|
304 |
+
)
|
305 |
+
if self.config.is_encoder_decoder:
|
306 |
+
cross_attentions += (outputs.cross_attentions,)
|
307 |
+
|
308 |
+
if output_hidden_states:
|
309 |
+
decoder_hidden_states += (
|
310 |
+
(outputs.decoder_hidden_states,)
|
311 |
+
if self.config.is_encoder_decoder
|
312 |
+
else (outputs.hidden_states,)
|
313 |
+
)
|
314 |
+
|
315 |
+
# sample
|
316 |
+
probs = nn.functional.softmax(next_token_scores, dim=-1)
|
317 |
+
next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)
|
318 |
+
|
319 |
+
# finished sentences should have their next token be a padding token
|
320 |
+
if eos_token_id is not None:
|
321 |
+
if pad_token_id is None:
|
322 |
+
raise ValueError("If `eos_token_id` is defined, make sure that `pad_token_id` is defined.")
|
323 |
+
next_tokens = next_tokens * unfinished_sequences + pad_token_id * (1 - unfinished_sequences)
|
324 |
+
|
325 |
+
yield next_tokens.cpu()
|
326 |
+
|
327 |
+
# update generated ids, model inputs, and length for next step
|
328 |
+
input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)
|
329 |
+
if streamer is not None:
|
330 |
+
streamer.put(next_tokens.cpu())
|
331 |
+
|
332 |
+
next_model_inputs = {}
|
333 |
+
if "cache_position" in model_inputs:
|
334 |
+
next_model_inputs['cache_position'] = model_inputs['cache_position']
|
335 |
+
|
336 |
+
try:
|
337 |
+
model_kwargs = self._update_model_kwargs_for_generation(
|
338 |
+
outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder,
|
339 |
+
# model_inputs=model_inputs
|
340 |
+
model_inputs=next_model_inputs,
|
341 |
+
)
|
342 |
+
except Exception as e:
|
343 |
+
# Older version dont have model_inputs
|
344 |
+
model_kwargs = self._update_model_kwargs_for_generation(
|
345 |
+
outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder,
|
346 |
+
)
|
347 |
+
|
348 |
+
|
349 |
+
# if eos_token was found in one sentence, set sentence to finished
|
350 |
+
if eos_token_id_tensor is not None:
|
351 |
+
unfinished_sequences = unfinished_sequences.mul(
|
352 |
+
next_tokens.tile(eos_token_id_tensor.shape[0], 1).ne(eos_token_id_tensor.unsqueeze(1)).prod(dim=0)
|
353 |
+
)
|
354 |
+
|
355 |
+
# stop when each sentence is finished
|
356 |
+
if unfinished_sequences.max() == 0:
|
357 |
+
this_peer_finished = True
|
358 |
+
|
359 |
+
# stop if we exceed the maximum length
|
360 |
+
if stopping_criteria(input_ids, scores):
|
361 |
+
this_peer_finished = True
|
362 |
+
|
363 |
+
if this_peer_finished and not synced_gpus:
|
364 |
+
break
|
365 |
+
|
366 |
+
if streamer is not None:
|
367 |
+
streamer.end()
|
368 |
+
|
369 |
+
# if return_dict_in_generate:
|
370 |
+
# if self.config.is_encoder_decoder:
|
371 |
+
# return GenerateEncoderDecoderOutput(
|
372 |
+
# sequences=input_ids,
|
373 |
+
# scores=scores,
|
374 |
+
# logits=raw_logits,
|
375 |
+
# encoder_attentions=encoder_attentions,
|
376 |
+
# encoder_hidden_states=encoder_hidden_states,
|
377 |
+
# decoder_attentions=decoder_attentions,
|
378 |
+
# cross_attentions=cross_attentions,
|
379 |
+
# decoder_hidden_states=decoder_hidden_states,
|
380 |
+
# past_key_values=model_kwargs.get("past_key_values"),
|
381 |
+
# )
|
382 |
+
# else:
|
383 |
+
# return GenerateDecoderOnlyOutput(
|
384 |
+
# sequences=input_ids,
|
385 |
+
# scores=scores,
|
386 |
+
# logits=raw_logits,
|
387 |
+
# attentions=decoder_attentions,
|
388 |
+
# hidden_states=decoder_hidden_states,
|
389 |
+
# past_key_values=model_kwargs.get("past_key_values"),
|
390 |
+
# )
|
391 |
+
# else:
|
392 |
+
# return input_ids
|
393 |
+
|
394 |
+
|
395 |
+
|
396 |
+
class TransformersEngine(BaseEngine):
|
397 |
+
@property
|
398 |
+
def max_position_embeddings(self) -> int:
|
399 |
+
return self._model.config.max_position_embeddings
|
400 |
+
|
401 |
+
@property
|
402 |
+
def tokenizer(self):
|
403 |
+
return self._tokenizer
|
404 |
+
|
405 |
+
def load_model(self):
|
406 |
+
from transformers import AutoTokenizer, AutoModelForCausalLM
|
407 |
+
import sys
|
408 |
+
# caution: path[0] is reserved for script path (or '' in REPL)
|
409 |
+
# sys.path.append(CODE_PATH)
|
410 |
+
self.model_path = model_path = MODEL_PATH
|
411 |
+
self.torch_dtype = torch.bfloat16 if DTYPE == 'bfloat16' else torch.float16
|
412 |
+
self.device_map = DEVICE
|
413 |
+
print(f'Loading model from {model_path} on {self.device_map} with {self.torch_dtype}')
|
414 |
+
|
415 |
+
self._tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
|
416 |
+
assert self._tokenizer.chat_template is not None and self._tokenizer.chat_template != "", f"{self._tokenizer.chat_template=} not found!"
|
417 |
+
self._model = AutoModelForCausalLM.from_pretrained(model_path, torch_dtype=self.torch_dtype, device_map=self.device_map, trust_remote_code=True).eval()
|
418 |
+
self._model.sample_old = self._model.sample
|
419 |
+
self._model._sample = types.MethodType(NewGenerationMixin.sample_stream, self._model)
|
420 |
+
print(self._model)
|
421 |
+
print(f"{self.max_position_embeddings=}")
|
422 |
+
|
423 |
+
def generate_yield_string(self, prompt, temperature, max_tokens, stop_strings: Optional[Tuple[str]] = None, **kwargs):
|
424 |
+
|
425 |
+
# ! MUST PUT INSIDE torch.no_grad() otherwise it will overflow OOM
|
426 |
+
with torch.no_grad():
|
427 |
+
inputs = self.tokenizer(prompt, return_tensors='pt')
|
428 |
+
num_tokens = inputs.input_ids.size(1)
|
429 |
+
|
430 |
+
inputs = {k: v.to(self.device_map) for k, v in inputs.items() if v is not None}
|
431 |
+
generator = self._model.generate(
|
432 |
+
**inputs,
|
433 |
+
do_sample=True,
|
434 |
+
temperature=temperature,
|
435 |
+
max_new_tokens=max_tokens,
|
436 |
+
pad_token_id=self.processor.tokenizer.pad_token_id,
|
437 |
+
)
|
438 |
+
|
439 |
+
out_tokens = []
|
440 |
+
response = None
|
441 |
+
for token in generator:
|
442 |
+
out_tokens.append(token.item())
|
443 |
+
response = self.processor.tokenizer.decode(out_tokens)
|
444 |
+
num_tokens += 1
|
445 |
+
# print(f"{num_tokens=}", end='\r')
|
446 |
+
# sys.stdout.flush()
|
447 |
+
yield response, num_tokens
|
448 |
+
|
449 |
+
if response is not None:
|
450 |
+
full_text = prompt + response
|
451 |
+
num_tokens = len(self.tokenizer.encode(full_text))
|
452 |
+
yield response, num_tokens
|
multipurpose_chatbot/engines/vllm_engine.py
ADDED
@@ -0,0 +1,233 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import numpy as np
|
3 |
+
import argparse
|
4 |
+
import gradio as gr
|
5 |
+
from typing import Any, Iterator
|
6 |
+
from typing import Iterator, List, Optional, Tuple
|
7 |
+
import filelock
|
8 |
+
import glob
|
9 |
+
import json
|
10 |
+
import time
|
11 |
+
from gradio.routes import Request
|
12 |
+
from gradio.utils import SyncToAsyncIterator, async_iteration
|
13 |
+
from gradio.helpers import special_args
|
14 |
+
import anyio
|
15 |
+
from typing import AsyncGenerator, Callable, Literal, Union, cast
|
16 |
+
|
17 |
+
from gradio_client.documentation import document, set_documentation_group
|
18 |
+
|
19 |
+
from typing import List, Optional, Union, Dict, Tuple
|
20 |
+
from tqdm.auto import tqdm
|
21 |
+
from huggingface_hub import snapshot_download
|
22 |
+
|
23 |
+
from gradio.components import Button
|
24 |
+
from gradio.events import Dependency, EventListenerMethod
|
25 |
+
|
26 |
+
from .base_engine import BaseEngine
|
27 |
+
# @@ environments ================
|
28 |
+
|
29 |
+
from ..configs import (
|
30 |
+
DTYPE,
|
31 |
+
TENSOR_PARALLEL,
|
32 |
+
MODEL_PATH,
|
33 |
+
QUANTIZATION,
|
34 |
+
MAX_TOKENS,
|
35 |
+
TEMPERATURE,
|
36 |
+
FREQUENCE_PENALTY,
|
37 |
+
PRESENCE_PENALTY,
|
38 |
+
GPU_MEMORY_UTILIZATION,
|
39 |
+
STREAM_CHECK_MULTIPLE,
|
40 |
+
STREAM_YIELD_MULTIPLE,
|
41 |
+
|
42 |
+
)
|
43 |
+
|
44 |
+
|
45 |
+
llm = None
|
46 |
+
demo = None
|
47 |
+
|
48 |
+
|
49 |
+
|
50 |
+
def vllm_abort(self):
|
51 |
+
sh = self.llm_engine.scheduler
|
52 |
+
for g in (sh.waiting + sh.running + sh.swapped):
|
53 |
+
sh.abort_seq_group(g.request_id)
|
54 |
+
from vllm.sequence import SequenceStatus
|
55 |
+
scheduler = self.llm_engine.scheduler
|
56 |
+
for state_queue in [scheduler.waiting, scheduler.running, scheduler.swapped]:
|
57 |
+
for seq_group in state_queue:
|
58 |
+
# if seq_group.request_id == request_id:
|
59 |
+
# Remove the sequence group from the state queue.
|
60 |
+
state_queue.remove(seq_group)
|
61 |
+
for seq in seq_group.seqs:
|
62 |
+
if seq.is_finished():
|
63 |
+
continue
|
64 |
+
scheduler.free_seq(seq, SequenceStatus.FINISHED_ABORTED)
|
65 |
+
|
66 |
+
|
67 |
+
def _vllm_run_engine(self: Any, use_tqdm: bool = False) -> Dict[str, Any]:
|
68 |
+
from vllm.outputs import RequestOutput
|
69 |
+
# Initialize tqdm.
|
70 |
+
if use_tqdm:
|
71 |
+
num_requests = self.llm_engine.get_num_unfinished_requests()
|
72 |
+
pbar = tqdm(total=num_requests, desc="Processed prompts")
|
73 |
+
# Run the engine.
|
74 |
+
outputs: Dict[str, RequestOutput] = {}
|
75 |
+
while self.llm_engine.has_unfinished_requests():
|
76 |
+
step_outputs = self.llm_engine.step()
|
77 |
+
for output in step_outputs:
|
78 |
+
outputs[output.request_id] = output
|
79 |
+
if len(outputs) > 0:
|
80 |
+
yield outputs
|
81 |
+
|
82 |
+
|
83 |
+
def vllm_generate_stream(
|
84 |
+
self: Any,
|
85 |
+
prompts: Optional[Union[str, List[str]]] = None,
|
86 |
+
sampling_params: Optional[Any] = None,
|
87 |
+
prompt_token_ids: Optional[List[List[int]]] = None,
|
88 |
+
use_tqdm: bool = False,
|
89 |
+
) -> Dict[str, Any]:
|
90 |
+
"""Generates the completions for the input prompts.
|
91 |
+
|
92 |
+
NOTE: This class automatically batches the given prompts, considering
|
93 |
+
the memory constraint. For the best performance, put all of your prompts
|
94 |
+
into a single list and pass it to this method.
|
95 |
+
|
96 |
+
Args:
|
97 |
+
prompts: A list of prompts to generate completions for.
|
98 |
+
sampling_params: The sampling parameters for text generation. If
|
99 |
+
None, we use the default sampling parameters.
|
100 |
+
prompt_token_ids: A list of token IDs for the prompts. If None, we
|
101 |
+
use the tokenizer to convert the prompts to token IDs.
|
102 |
+
use_tqdm: Whether to use tqdm to display the progress bar.
|
103 |
+
|
104 |
+
Returns:
|
105 |
+
A list of `RequestOutput` objects containing the generated
|
106 |
+
completions in the same order as the input prompts.
|
107 |
+
"""
|
108 |
+
from vllm import LLM, SamplingParams
|
109 |
+
if prompts is None and prompt_token_ids is None:
|
110 |
+
raise ValueError("Either prompts or prompt_token_ids must be "
|
111 |
+
"provided.")
|
112 |
+
if isinstance(prompts, str):
|
113 |
+
# Convert a single prompt to a list.
|
114 |
+
prompts = [prompts]
|
115 |
+
if prompts is not None and prompt_token_ids is not None:
|
116 |
+
if len(prompts) != len(prompt_token_ids):
|
117 |
+
raise ValueError("The lengths of prompts and prompt_token_ids "
|
118 |
+
"must be the same.")
|
119 |
+
if sampling_params is None:
|
120 |
+
# Use default sampling params.
|
121 |
+
sampling_params = SamplingParams()
|
122 |
+
# Add requests to the engine.
|
123 |
+
if prompts is not None:
|
124 |
+
num_requests = len(prompts)
|
125 |
+
else:
|
126 |
+
num_requests = len(prompt_token_ids)
|
127 |
+
for i in range(num_requests):
|
128 |
+
prompt = prompts[i] if prompts is not None else None
|
129 |
+
if prompt_token_ids is None:
|
130 |
+
token_ids = None
|
131 |
+
else:
|
132 |
+
token_ids = prompt_token_ids[i]
|
133 |
+
self._add_request(prompt, sampling_params, token_ids)
|
134 |
+
# return self._run_engine(use_tqdm)
|
135 |
+
yield from _vllm_run_engine(self, use_tqdm)
|
136 |
+
|
137 |
+
|
138 |
+
|
139 |
+
class VllmEngine(BaseEngine):
|
140 |
+
def __init__(self, **kwargs) -> None:
|
141 |
+
super().__init__(**kwargs)
|
142 |
+
|
143 |
+
@property
|
144 |
+
def tokenizer(self):
|
145 |
+
return self._model.get_tokenizer()
|
146 |
+
|
147 |
+
def load_model(self, ):
|
148 |
+
import torch
|
149 |
+
try:
|
150 |
+
compute_capability = torch.cuda.get_device_capability()
|
151 |
+
print(f'Torch CUDA compute_capability: {compute_capability}')
|
152 |
+
except Exception as e:
|
153 |
+
print(f'Failed to print compute_capability version: {e}')
|
154 |
+
|
155 |
+
import vllm
|
156 |
+
from vllm import LLM
|
157 |
+
|
158 |
+
print(f'VLLM: {vllm.__version__=}')
|
159 |
+
|
160 |
+
if QUANTIZATION == 'awq':
|
161 |
+
print(F'Load model in int4 quantization')
|
162 |
+
llm = LLM(
|
163 |
+
model=MODEL_PATH,
|
164 |
+
dtype="float16",
|
165 |
+
tensor_parallel_size=TENSOR_PARALLEL,
|
166 |
+
gpu_memory_utilization=GPU_MEMORY_UTILIZATION,
|
167 |
+
quantization="awq",
|
168 |
+
max_model_len=MAX_TOKENS
|
169 |
+
)
|
170 |
+
else:
|
171 |
+
llm = LLM(
|
172 |
+
model=MODEL_PATH,
|
173 |
+
dtype=DTYPE,
|
174 |
+
tensor_parallel_size=TENSOR_PARALLEL,
|
175 |
+
gpu_memory_utilization=GPU_MEMORY_UTILIZATION,
|
176 |
+
max_model_len=MAX_TOKENS
|
177 |
+
)
|
178 |
+
|
179 |
+
try:
|
180 |
+
print(llm.llm_engine.workers[0].model)
|
181 |
+
except Exception as e:
|
182 |
+
print(f'Cannot print model worker: {e}')
|
183 |
+
|
184 |
+
try:
|
185 |
+
llm.llm_engine.scheduler_config.max_model_len = MAX_TOKENS
|
186 |
+
llm.llm_engine.scheduler_config.max_num_batched_tokens = MAX_TOKENS
|
187 |
+
except Exception as e:
|
188 |
+
print(f'Cannot set parameters: {e}')
|
189 |
+
|
190 |
+
self._model = llm
|
191 |
+
|
192 |
+
def generate_yield_string(self, prompt, temperature, max_tokens, stop_strings: Optional[Tuple[str]] = None, **kwargs):
|
193 |
+
from vllm import SamplingParams
|
194 |
+
# ! must abort previous ones
|
195 |
+
vllm_abort(llm)
|
196 |
+
sampling_params = SamplingParams(
|
197 |
+
temperature=temperature,
|
198 |
+
max_tokens=max_tokens,
|
199 |
+
# frequency_penalty=frequency_penalty,
|
200 |
+
# presence_penalty=presence_penalty,
|
201 |
+
stop=stop_strings,
|
202 |
+
)
|
203 |
+
cur_out = None
|
204 |
+
num_tokens = len(self.tokenizer.encode(prompt))
|
205 |
+
for j, gen in enumerate(vllm_generate_stream(llm, prompt, sampling_params)):
|
206 |
+
if cur_out is not None and (STREAM_YIELD_MULTIPLE < 1 or j % STREAM_YIELD_MULTIPLE == 0) and j > 0:
|
207 |
+
yield cur_out, num_tokens
|
208 |
+
assert len(gen) == 1, f'{gen}'
|
209 |
+
item = next(iter(gen.values()))
|
210 |
+
cur_out = item.outputs[0].text
|
211 |
+
|
212 |
+
if cur_out is not None:
|
213 |
+
full_text = prompt + cur_out
|
214 |
+
num_tokens = len(self.tokenizer.encode(full_text))
|
215 |
+
yield cur_out, num_tokens
|
216 |
+
|
217 |
+
def batch_generate(self, prompts, temperature, max_tokens, stop_strings: Optional[Tuple[str]] = None, **kwargs):
|
218 |
+
"""
|
219 |
+
Only vllm should support this, the other engines is only batch=1 only
|
220 |
+
"""
|
221 |
+
from vllm import SamplingParams
|
222 |
+
# ! must abort previous ones
|
223 |
+
vllm_abort(llm)
|
224 |
+
sampling_params = SamplingParams(
|
225 |
+
temperature=temperature,
|
226 |
+
max_tokens=max_tokens,
|
227 |
+
# frequency_penalty=frequency_penalty,
|
228 |
+
# presence_penalty=presence_penalty,
|
229 |
+
stop=stop_strings,
|
230 |
+
)
|
231 |
+
generated = llm.generate(prompts, sampling_params, use_tqdm=False)
|
232 |
+
responses = [g.outputs[0].text for g in generated]
|
233 |
+
return responses
|
multipurpose_chatbot/globals.py
ADDED
@@ -0,0 +1,33 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
|
3 |
+
global MODEL_ENGINE
|
4 |
+
|
5 |
+
from multipurpose_chatbot.engines import load_multipurpose_chatbot_engine
|
6 |
+
from multipurpose_chatbot.demos import get_demo_class
|
7 |
+
|
8 |
+
from .configs import (
|
9 |
+
BACKEND,
|
10 |
+
RAG_EMBED_MODEL_NAME,
|
11 |
+
)
|
12 |
+
|
13 |
+
MODEL_ENGINE = load_multipurpose_chatbot_engine(BACKEND)
|
14 |
+
|
15 |
+
|
16 |
+
RAG_CURRENT_FILE, RAG_EMBED, RAG_CURRENT_VECTORSTORE = None, None, None
|
17 |
+
|
18 |
+
|
19 |
+
def load_embeddings():
|
20 |
+
global RAG_EMBED
|
21 |
+
if RAG_EMBED is None:
|
22 |
+
from langchain_community.embeddings import HuggingFaceEmbeddings, HuggingFaceBgeEmbeddings
|
23 |
+
print(f'LOading embeddings: {RAG_EMBED_MODEL_NAME}')
|
24 |
+
RAG_EMBED = HuggingFaceEmbeddings(model_name=RAG_EMBED_MODEL_NAME, model_kwargs={'trust_remote_code':True, "device": "cpu"})
|
25 |
+
else:
|
26 |
+
print(f'RAG_EMBED ALREADY EXIST: {RAG_EMBED_MODEL_NAME}: {RAG_EMBED=}')
|
27 |
+
return RAG_EMBED
|
28 |
+
|
29 |
+
|
30 |
+
def get_rag_embeddings():
|
31 |
+
return load_embeddings()
|
32 |
+
|
33 |
+
|
pyproject.toml
ADDED
File without changes
|
requirements.txt
ADDED
@@ -0,0 +1,11 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
torch
|
2 |
+
gradio
|
3 |
+
tiktoken
|
4 |
+
openai
|
5 |
+
transformers
|
6 |
+
langchain
|
7 |
+
langchain-community
|
8 |
+
langchain-core
|
9 |
+
chromadb
|
10 |
+
pypdf
|
11 |
+
docx2txt
|
transformers_requirements.txt
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
transformers
|
vllm_requirements.txt
ADDED
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
1 |
+
transformers
|
2 |
+
vllm
|