Spaces:
Running
on
Zero
Running
on
Zero
Upload 40 files
Browse files- .gitattributes +1 -0
- LICENSE +201 -0
- app.py +73 -1729
- assets/.DS_Store +0 -0
- assets/attention_all_you_need.pdf +0 -0
- assets/attention_short.pdf +0 -0
- assets/dog_monalisa.jpeg +0 -0
- assets/upload_chat.json +10 -0
- assets/upload_few_shot.json +10 -0
- llama_cpp_requirements.txt +1 -0
- mlx_requirements.txt +2 -0
- multipurpose_chatbot/.DS_Store +0 -0
- multipurpose_chatbot/__init__.py +0 -0
- multipurpose_chatbot/configs.py +140 -0
- multipurpose_chatbot/demos/.DS_Store +0 -0
- multipurpose_chatbot/demos/__init__.py +9 -0
- multipurpose_chatbot/demos/base_demo.py +105 -0
- multipurpose_chatbot/demos/batch_inference.py +0 -0
- multipurpose_chatbot/demos/chat_interface.py +692 -0
- multipurpose_chatbot/demos/multimodal_chat_interface.py +1295 -0
- multipurpose_chatbot/demos/multimodal_preference_interface.py +794 -0
- multipurpose_chatbot/demos/rag_chat_interface.py +638 -0
- multipurpose_chatbot/demos/text_completion.py +199 -0
- multipurpose_chatbot/engines/.DS_Store +0 -0
- multipurpose_chatbot/engines/__init__.py +53 -0
- multipurpose_chatbot/engines/base_engine.py +42 -0
- multipurpose_chatbot/engines/debug_engine.py +49 -0
- multipurpose_chatbot/engines/llama_cpp_engine.py +131 -0
- multipurpose_chatbot/engines/llava_llama_cpp_engine.py +280 -0
- multipurpose_chatbot/engines/mlx_engine.py +202 -0
- multipurpose_chatbot/engines/modeling_sealmm.py +1091 -0
- multipurpose_chatbot/engines/sealmmm_engine.py +269 -0
- multipurpose_chatbot/engines/transformers_engine.py +454 -0
- multipurpose_chatbot/engines/vllm_engine.py +233 -0
- multipurpose_chatbot/globals.py +33 -0
- pyproject.toml +0 -0
- requirements.txt +11 -13
- seallm_app.py +1787 -0
- seammm_2.png +3 -0
- transformers_requirements.txt +1 -0
- vllm_requirements.txt +2 -0
.gitattributes
CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
36 |
+
seammm_2.png filter=lfs diff=lfs merge=lfs -text
|
LICENSE
ADDED
@@ -0,0 +1,201 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
Apache License
|
2 |
+
Version 2.0, January 2004
|
3 |
+
http://www.apache.org/licenses/
|
4 |
+
|
5 |
+
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
|
6 |
+
|
7 |
+
1. Definitions.
|
8 |
+
|
9 |
+
"License" shall mean the terms and conditions for use, reproduction,
|
10 |
+
and distribution as defined by Sections 1 through 9 of this document.
|
11 |
+
|
12 |
+
"Licensor" shall mean the copyright owner or entity authorized by
|
13 |
+
the copyright owner that is granting the License.
|
14 |
+
|
15 |
+
"Legal Entity" shall mean the union of the acting entity and all
|
16 |
+
other entities that control, are controlled by, or are under common
|
17 |
+
control with that entity. For the purposes of this definition,
|
18 |
+
"control" means (i) the power, direct or indirect, to cause the
|
19 |
+
direction or management of such entity, whether by contract or
|
20 |
+
otherwise, or (ii) ownership of fifty percent (50%) or more of the
|
21 |
+
outstanding shares, or (iii) beneficial ownership of such entity.
|
22 |
+
|
23 |
+
"You" (or "Your") shall mean an individual or Legal Entity
|
24 |
+
exercising permissions granted by this License.
|
25 |
+
|
26 |
+
"Source" form shall mean the preferred form for making modifications,
|
27 |
+
including but not limited to software source code, documentation
|
28 |
+
source, and configuration files.
|
29 |
+
|
30 |
+
"Object" form shall mean any form resulting from mechanical
|
31 |
+
transformation or translation of a Source form, including but
|
32 |
+
not limited to compiled object code, generated documentation,
|
33 |
+
and conversions to other media types.
|
34 |
+
|
35 |
+
"Work" shall mean the work of authorship, whether in Source or
|
36 |
+
Object form, made available under the License, as indicated by a
|
37 |
+
copyright notice that is included in or attached to the work
|
38 |
+
(an example is provided in the Appendix below).
|
39 |
+
|
40 |
+
"Derivative Works" shall mean any work, whether in Source or Object
|
41 |
+
form, that is based on (or derived from) the Work and for which the
|
42 |
+
editorial revisions, annotations, elaborations, or other modifications
|
43 |
+
represent, as a whole, an original work of authorship. For the purposes
|
44 |
+
of this License, Derivative Works shall not include works that remain
|
45 |
+
separable from, or merely link (or bind by name) to the interfaces of,
|
46 |
+
the Work and Derivative Works thereof.
|
47 |
+
|
48 |
+
"Contribution" shall mean any work of authorship, including
|
49 |
+
the original version of the Work and any modifications or additions
|
50 |
+
to that Work or Derivative Works thereof, that is intentionally
|
51 |
+
submitted to Licensor for inclusion in the Work by the copyright owner
|
52 |
+
or by an individual or Legal Entity authorized to submit on behalf of
|
53 |
+
the copyright owner. For the purposes of this definition, "submitted"
|
54 |
+
means any form of electronic, verbal, or written communication sent
|
55 |
+
to the Licensor or its representatives, including but not limited to
|
56 |
+
communication on electronic mailing lists, source code control systems,
|
57 |
+
and issue tracking systems that are managed by, or on behalf of, the
|
58 |
+
Licensor for the purpose of discussing and improving the Work, but
|
59 |
+
excluding communication that is conspicuously marked or otherwise
|
60 |
+
designated in writing by the copyright owner as "Not a Contribution."
|
61 |
+
|
62 |
+
"Contributor" shall mean Licensor and any individual or Legal Entity
|
63 |
+
on behalf of whom a Contribution has been received by Licensor and
|
64 |
+
subsequently incorporated within the Work.
|
65 |
+
|
66 |
+
2. Grant of Copyright License. Subject to the terms and conditions of
|
67 |
+
this License, each Contributor hereby grants to You a perpetual,
|
68 |
+
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
69 |
+
copyright license to reproduce, prepare Derivative Works of,
|
70 |
+
publicly display, publicly perform, sublicense, and distribute the
|
71 |
+
Work and such Derivative Works in Source or Object form.
|
72 |
+
|
73 |
+
3. Grant of Patent License. Subject to the terms and conditions of
|
74 |
+
this License, each Contributor hereby grants to You a perpetual,
|
75 |
+
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
76 |
+
(except as stated in this section) patent license to make, have made,
|
77 |
+
use, offer to sell, sell, import, and otherwise transfer the Work,
|
78 |
+
where such license applies only to those patent claims licensable
|
79 |
+
by such Contributor that are necessarily infringed by their
|
80 |
+
Contribution(s) alone or by combination of their Contribution(s)
|
81 |
+
with the Work to which such Contribution(s) was submitted. If You
|
82 |
+
institute patent litigation against any entity (including a
|
83 |
+
cross-claim or counterclaim in a lawsuit) alleging that the Work
|
84 |
+
or a Contribution incorporated within the Work constitutes direct
|
85 |
+
or contributory patent infringement, then any patent licenses
|
86 |
+
granted to You under this License for that Work shall terminate
|
87 |
+
as of the date such litigation is filed.
|
88 |
+
|
89 |
+
4. Redistribution. You may reproduce and distribute copies of the
|
90 |
+
Work or Derivative Works thereof in any medium, with or without
|
91 |
+
modifications, and in Source or Object form, provided that You
|
92 |
+
meet the following conditions:
|
93 |
+
|
94 |
+
(a) You must give any other recipients of the Work or
|
95 |
+
Derivative Works a copy of this License; and
|
96 |
+
|
97 |
+
(b) You must cause any modified files to carry prominent notices
|
98 |
+
stating that You changed the files; and
|
99 |
+
|
100 |
+
(c) You must retain, in the Source form of any Derivative Works
|
101 |
+
that You distribute, all copyright, patent, trademark, and
|
102 |
+
attribution notices from the Source form of the Work,
|
103 |
+
excluding those notices that do not pertain to any part of
|
104 |
+
the Derivative Works; and
|
105 |
+
|
106 |
+
(d) If the Work includes a "NOTICE" text file as part of its
|
107 |
+
distribution, then any Derivative Works that You distribute must
|
108 |
+
include a readable copy of the attribution notices contained
|
109 |
+
within such NOTICE file, excluding those notices that do not
|
110 |
+
pertain to any part of the Derivative Works, in at least one
|
111 |
+
of the following places: within a NOTICE text file distributed
|
112 |
+
as part of the Derivative Works; within the Source form or
|
113 |
+
documentation, if provided along with the Derivative Works; or,
|
114 |
+
within a display generated by the Derivative Works, if and
|
115 |
+
wherever such third-party notices normally appear. The contents
|
116 |
+
of the NOTICE file are for informational purposes only and
|
117 |
+
do not modify the License. You may add Your own attribution
|
118 |
+
notices within Derivative Works that You distribute, alongside
|
119 |
+
or as an addendum to the NOTICE text from the Work, provided
|
120 |
+
that such additional attribution notices cannot be construed
|
121 |
+
as modifying the License.
|
122 |
+
|
123 |
+
You may add Your own copyright statement to Your modifications and
|
124 |
+
may provide additional or different license terms and conditions
|
125 |
+
for use, reproduction, or distribution of Your modifications, or
|
126 |
+
for any such Derivative Works as a whole, provided Your use,
|
127 |
+
reproduction, and distribution of the Work otherwise complies with
|
128 |
+
the conditions stated in this License.
|
129 |
+
|
130 |
+
5. Submission of Contributions. Unless You explicitly state otherwise,
|
131 |
+
any Contribution intentionally submitted for inclusion in the Work
|
132 |
+
by You to the Licensor shall be under the terms and conditions of
|
133 |
+
this License, without any additional terms or conditions.
|
134 |
+
Notwithstanding the above, nothing herein shall supersede or modify
|
135 |
+
the terms of any separate license agreement you may have executed
|
136 |
+
with Licensor regarding such Contributions.
|
137 |
+
|
138 |
+
6. Trademarks. This License does not grant permission to use the trade
|
139 |
+
names, trademarks, service marks, or product names of the Licensor,
|
140 |
+
except as required for reasonable and customary use in describing the
|
141 |
+
origin of the Work and reproducing the content of the NOTICE file.
|
142 |
+
|
143 |
+
7. Disclaimer of Warranty. Unless required by applicable law or
|
144 |
+
agreed to in writing, Licensor provides the Work (and each
|
145 |
+
Contributor provides its Contributions) on an "AS IS" BASIS,
|
146 |
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
|
147 |
+
implied, including, without limitation, any warranties or conditions
|
148 |
+
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
|
149 |
+
PARTICULAR PURPOSE. You are solely responsible for determining the
|
150 |
+
appropriateness of using or redistributing the Work and assume any
|
151 |
+
risks associated with Your exercise of permissions under this License.
|
152 |
+
|
153 |
+
8. Limitation of Liability. In no event and under no legal theory,
|
154 |
+
whether in tort (including negligence), contract, or otherwise,
|
155 |
+
unless required by applicable law (such as deliberate and grossly
|
156 |
+
negligent acts) or agreed to in writing, shall any Contributor be
|
157 |
+
liable to You for damages, including any direct, indirect, special,
|
158 |
+
incidental, or consequential damages of any character arising as a
|
159 |
+
result of this License or out of the use or inability to use the
|
160 |
+
Work (including but not limited to damages for loss of goodwill,
|
161 |
+
work stoppage, computer failure or malfunction, or any and all
|
162 |
+
other commercial damages or losses), even if such Contributor
|
163 |
+
has been advised of the possibility of such damages.
|
164 |
+
|
165 |
+
9. Accepting Warranty or Additional Liability. While redistributing
|
166 |
+
the Work or Derivative Works thereof, You may choose to offer,
|
167 |
+
and charge a fee for, acceptance of support, warranty, indemnity,
|
168 |
+
or other liability obligations and/or rights consistent with this
|
169 |
+
License. However, in accepting such obligations, You may act only
|
170 |
+
on Your own behalf and on Your sole responsibility, not on behalf
|
171 |
+
of any other Contributor, and only if You agree to indemnify,
|
172 |
+
defend, and hold each Contributor harmless for any liability
|
173 |
+
incurred by, or claims asserted against, such Contributor by reason
|
174 |
+
of your accepting any such warranty or additional liability.
|
175 |
+
|
176 |
+
END OF TERMS AND CONDITIONS
|
177 |
+
|
178 |
+
APPENDIX: How to apply the Apache License to your work.
|
179 |
+
|
180 |
+
To apply the Apache License to your work, attach the following
|
181 |
+
boilerplate notice, with the fields enclosed by brackets "[]"
|
182 |
+
replaced with your own identifying information. (Don't include
|
183 |
+
the brackets!) The text should be enclosed in the appropriate
|
184 |
+
comment syntax for the file format. We also recommend that a
|
185 |
+
file or class name and description of purpose be included on the
|
186 |
+
same "printed page" as the copyright notice for easier
|
187 |
+
identification within third-party archives.
|
188 |
+
|
189 |
+
Copyright [yyyy] [name of copyright owner]
|
190 |
+
|
191 |
+
Licensed under the Apache License, Version 2.0 (the "License");
|
192 |
+
you may not use this file except in compliance with the License.
|
193 |
+
You may obtain a copy of the License at
|
194 |
+
|
195 |
+
http://www.apache.org/licenses/LICENSE-2.0
|
196 |
+
|
197 |
+
Unless required by applicable law or agreed to in writing, software
|
198 |
+
distributed under the License is distributed on an "AS IS" BASIS,
|
199 |
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
200 |
+
See the License for the specific language governing permissions and
|
201 |
+
limitations under the License.
|
app.py
CHANGED
@@ -3,14 +3,15 @@
|
|
3 |
|
4 |
# Description:
|
5 |
"""
|
6 |
-
|
7 |
"""
|
8 |
|
9 |
|
10 |
import os
|
|
|
11 |
import numpy as np
|
12 |
import argparse
|
13 |
-
import torch
|
14 |
import gradio as gr
|
15 |
from typing import Any, Iterator
|
16 |
from typing import Iterator, List, Optional, Tuple
|
@@ -29,1759 +30,102 @@ from gradio_client.documentation import document, set_documentation_group
|
|
29 |
from typing import List, Optional, Union, Dict, Tuple
|
30 |
from tqdm.auto import tqdm
|
31 |
from huggingface_hub import snapshot_download
|
32 |
-
|
33 |
-
|
34 |
-
# @@ environments ================
|
35 |
-
|
36 |
-
DEBUG = bool(int(os.environ.get("DEBUG", "1")))
|
37 |
-
|
38 |
-
# List of languages to block
|
39 |
-
BLOCK_LANGS = str(os.environ.get("BLOCK_LANGS", ""))
|
40 |
-
BLOCK_LANGS = [x.strip() for x in BLOCK_LANGS.strip().split(";")] if len(BLOCK_LANGS.strip()) > 0 else []
|
41 |
-
|
42 |
-
# for lang block, wether to block in history too
|
43 |
-
LANG_BLOCK_HISTORY = bool(int(os.environ.get("LANG_BLOCK_HISTORY", "0")))
|
44 |
-
TENSOR_PARALLEL = int(os.environ.get("TENSOR_PARALLEL", "1"))
|
45 |
-
DTYPE = os.environ.get("DTYPE", "bfloat16")
|
46 |
-
|
47 |
-
# ! (no debug) whether to download HF_MODEL_NAME and save to MODEL_PATH
|
48 |
-
DOWNLOAD_SNAPSHOT = bool(int(os.environ.get("DOWNLOAD_SNAPSHOT", "0")))
|
49 |
-
LOG_RESPONSE = bool(int(os.environ.get("LOG_RESPONSE", "0")))
|
50 |
-
# ! show model path in the demo page, only for internal
|
51 |
-
DISPLAY_MODEL_PATH = bool(int(os.environ.get("DISPLAY_MODEL_PATH", "1")))
|
52 |
-
|
53 |
-
# ! uploaded model path, will be downloaded to MODEL_PATH
|
54 |
-
HF_MODEL_NAME = os.environ.get("HF_MODEL_NAME", "DAMO-NLP-SG/seal-13b-chat-a")
|
55 |
-
# ! if model is private, need HF_TOKEN to access the model
|
56 |
-
HF_TOKEN = os.environ.get("HF_TOKEN", None)
|
57 |
-
# ! path where the model is downloaded, either on ./ or persistent disc
|
58 |
-
MODEL_PATH = os.environ.get("MODEL_PATH", "./seal-13b-chat-a")
|
59 |
-
|
60 |
-
# ! log path
|
61 |
-
LOG_PATH = os.environ.get("LOG_PATH", "").strip()
|
62 |
-
LOG_FILE = None
|
63 |
-
SAVE_LOGS = LOG_PATH is not None and LOG_PATH != ''
|
64 |
-
if SAVE_LOGS:
|
65 |
-
if os.path.exists(LOG_PATH):
|
66 |
-
print(f'LOG_PATH exist: {LOG_PATH}')
|
67 |
-
else:
|
68 |
-
LOG_DIR = os.path.dirname(LOG_PATH)
|
69 |
-
os.makedirs(LOG_DIR, exist_ok=True)
|
70 |
-
|
71 |
-
# ! get LOG_PATH as aggregated outputs in log
|
72 |
-
GET_LOG_CMD = os.environ.get("GET_LOG_CMD", "").strip()
|
73 |
-
|
74 |
-
print(f'SAVE_LOGS: {SAVE_LOGS} | {LOG_PATH}')
|
75 |
-
# print(f'GET_LOG_CMD: {GET_LOG_CMD}')
|
76 |
-
|
77 |
-
# ! !! Whether to delete the folder, ONLY SET THIS IF YOU WANT TO DELETE SAVED MODEL ON PERSISTENT DISC
|
78 |
-
DELETE_FOLDER = os.environ.get("DELETE_FOLDER", "")
|
79 |
-
IS_DELETE_FOLDER = DELETE_FOLDER is not None and os.path.exists(DELETE_FOLDER)
|
80 |
-
print(f'DELETE_FOLDER: {DELETE_FOLDER} | {DOWNLOAD_SNAPSHOT=}')
|
81 |
-
|
82 |
-
# ! list of keywords to disabled as security measures to comply with local regulation
|
83 |
-
KEYWORDS = os.environ.get("KEYWORDS", "").strip()
|
84 |
-
KEYWORDS = KEYWORDS.split(";") if len(KEYWORDS) > 0 else []
|
85 |
-
KEYWORDS = [x.lower() for x in KEYWORDS]
|
86 |
-
|
87 |
-
# bypass
|
88 |
-
BYPASS_USERS = os.environ.get("BYPASS_USERS", "").strip()
|
89 |
-
BYPASS_USERS = BYPASS_USERS.split(";") if len(BYPASS_USERS) > 0 else []
|
90 |
-
|
91 |
-
# gradio config
|
92 |
-
PORT = int(os.environ.get("PORT", "7860"))
|
93 |
-
# how many iterations to yield response
|
94 |
-
STREAM_YIELD_MULTIPLE = int(os.environ.get("STREAM_YIELD_MULTIPLE", "1"))
|
95 |
-
# how many iterations to perform safety check on response
|
96 |
-
STREAM_CHECK_MULTIPLE = int(os.environ.get("STREAM_CHECK_MULTIPLE", "0"))
|
97 |
-
|
98 |
-
# whether to enable to popup accept user
|
99 |
-
ENABLE_AGREE_POPUP = bool(int(os.environ.get("ENABLE_AGREE_POPUP", "0")))
|
100 |
-
|
101 |
-
# self explanatory
|
102 |
-
MAX_TOKENS = int(os.environ.get("MAX_TOKENS", "2048"))
|
103 |
-
TEMPERATURE = float(os.environ.get("TEMPERATURE", "0.1"))
|
104 |
-
FREQUENCE_PENALTY = float(os.environ.get("FREQUENCE_PENALTY", "0.1"))
|
105 |
-
PRESENCE_PENALTY = float(os.environ.get("PRESENCE_PENALTY", "0.0"))
|
106 |
-
gpu_memory_utilization = float(os.environ.get("gpu_memory_utilization", "0.9"))
|
107 |
-
|
108 |
-
# whether to enable quantization, currently not in use
|
109 |
-
QUANTIZATION = str(os.environ.get("QUANTIZATION", ""))
|
110 |
-
|
111 |
-
|
112 |
-
# Batch inference file upload
|
113 |
-
ENABLE_BATCH_INFER = bool(int(os.environ.get("ENABLE_BATCH_INFER", "1")))
|
114 |
-
BATCH_INFER_MAX_ITEMS = int(os.environ.get("BATCH_INFER_MAX_ITEMS", "100"))
|
115 |
-
BATCH_INFER_MAX_FILE_SIZE = int(os.environ.get("BATCH_INFER_MAX_FILE_SIZE", "500"))
|
116 |
-
BATCH_INFER_MAX_PROMPT_TOKENS = int(os.environ.get("BATCH_INFER_MAX_PROMPT_TOKENS", "4000"))
|
117 |
-
BATCH_INFER_SAVE_TMP_FILE = os.environ.get("BATCH_INFER_SAVE_TMP_FILE", "./tmp/pred.json")
|
118 |
-
|
119 |
-
#
|
120 |
-
DATA_SET_REPO_PATH = str(os.environ.get("DATA_SET_REPO_PATH", ""))
|
121 |
-
DATA_SET_REPO = None
|
122 |
-
|
123 |
-
"""
|
124 |
-
Internal instructions of how to configure the DEMO
|
125 |
-
|
126 |
-
1. Upload SFT model as a model to huggingface: hugginface/models/seal_13b_a
|
127 |
-
2. If the model weights is private, set HF_TOKEN=<your private hf token> in https://huggingface.co/spaces/????/?????/settings
|
128 |
-
3. space config env: `HF_MODEL_NAME=SeaLLMs/seal-13b-chat-a` or the underlining model
|
129 |
-
4. If enable persistent storage: set
|
130 |
-
HF_HOME=/data/.huggingface
|
131 |
-
MODEL_PATH=/data/.huggingface/seal-13b-chat-a
|
132 |
-
if not:
|
133 |
-
MODEL_PATH=./seal-13b-chat-a
|
134 |
-
|
135 |
-
|
136 |
-
HF_HOME=/data/.huggingface
|
137 |
-
MODEL_PATH=/data/ckpt/seal-13b-chat-a
|
138 |
-
DELETE_FOLDER=/data/
|
139 |
-
|
140 |
-
"""
|
141 |
-
|
142 |
-
# ==============================
|
143 |
-
print(f'DEBUG mode: {DEBUG}')
|
144 |
-
print(f'Torch version: {torch.__version__}')
|
145 |
-
try:
|
146 |
-
print(f'Torch CUDA version: {torch.version.cuda}')
|
147 |
-
except Exception as e:
|
148 |
-
print(f'Failed to print cuda version: {e}')
|
149 |
-
|
150 |
-
try:
|
151 |
-
compute_capability = torch.cuda.get_device_capability()
|
152 |
-
print(f'Torch CUDA compute_capability: {compute_capability}')
|
153 |
-
except Exception as e:
|
154 |
-
print(f'Failed to print compute_capability version: {e}')
|
155 |
-
|
156 |
-
|
157 |
-
# @@ constants ================
|
158 |
-
|
159 |
-
DTYPES = {
|
160 |
-
'float16': torch.float16,
|
161 |
-
'bfloat16': torch.bfloat16
|
162 |
-
}
|
163 |
-
|
164 |
-
llm = None
|
165 |
-
demo = None
|
166 |
-
|
167 |
-
|
168 |
-
BOS_TOKEN = '<s>'
|
169 |
-
EOS_TOKEN = '</s>'
|
170 |
-
|
171 |
-
|
172 |
-
SYSTEM_PROMPT_1 = """You are a helpful, respectful, honest and safe AI assistant built by Alibaba Group."""
|
173 |
-
|
174 |
-
|
175 |
-
|
176 |
-
# ######### RAG PREPARE
|
177 |
-
RAG_CURRENT_FILE, RAG_EMBED, RAG_CURRENT_VECTORSTORE = None, None, None
|
178 |
-
|
179 |
-
# RAG_EMBED_MODEL_NAME = "sentence-transformers/all-MiniLM-L6-v2"
|
180 |
-
RAG_EMBED_MODEL_NAME = "sentence-transformers/LaBSE"
|
181 |
-
|
182 |
-
|
183 |
-
def load_embeddings():
|
184 |
-
global RAG_EMBED
|
185 |
-
if RAG_EMBED is None:
|
186 |
-
from langchain_community.embeddings import HuggingFaceEmbeddings, HuggingFaceBgeEmbeddings
|
187 |
-
print(f'LOading embeddings: {RAG_EMBED_MODEL_NAME}')
|
188 |
-
RAG_EMBED = HuggingFaceEmbeddings(model_name=RAG_EMBED_MODEL_NAME, model_kwargs={'trust_remote_code':True, "device": "cpu"})
|
189 |
-
else:
|
190 |
-
print(f'RAG_EMBED ALREADY EXIST: {RAG_EMBED_MODEL_NAME}: {RAG_EMBED=}')
|
191 |
-
return RAG_EMBED
|
192 |
-
|
193 |
-
|
194 |
-
def get_rag_embeddings():
|
195 |
-
return load_embeddings()
|
196 |
-
|
197 |
-
_ = get_rag_embeddings()
|
198 |
-
|
199 |
-
RAG_CURRENT_VECTORSTORE = None
|
200 |
-
|
201 |
-
def load_document_split_vectorstore(file_path):
|
202 |
-
global RAG_CURRENT_FILE, RAG_EMBED, RAG_CURRENT_VECTORSTORE
|
203 |
-
from langchain.text_splitter import RecursiveCharacterTextSplitter
|
204 |
-
from langchain_community.embeddings import HuggingFaceEmbeddings, HuggingFaceBgeEmbeddings
|
205 |
-
from langchain_community.vectorstores import Chroma, FAISS
|
206 |
-
from langchain_community.document_loaders import PyPDFLoader, Docx2txtLoader, TextLoader
|
207 |
-
# assert RAG_EMBED is not None
|
208 |
-
splitter = RecursiveCharacterTextSplitter(chunk_size=1024, chunk_overlap=50)
|
209 |
-
if file_path.endswith('.pdf'):
|
210 |
-
loader = PyPDFLoader(file_path)
|
211 |
-
elif file_path.endswith('.docx'):
|
212 |
-
loader = Docx2txtLoader(file_path)
|
213 |
-
elif file_path.endswith('.txt'):
|
214 |
-
loader = TextLoader(file_path)
|
215 |
-
splits = loader.load_and_split(splitter)
|
216 |
-
RAG_CURRENT_VECTORSTORE = FAISS.from_texts(texts=[s.page_content for s in splits], embedding=get_rag_embeddings())
|
217 |
-
return RAG_CURRENT_VECTORSTORE
|
218 |
-
|
219 |
-
|
220 |
-
def docs_to_rag_context(docs: List[str]):
|
221 |
-
contexts = "\n".join([d.page_content for d in docs])
|
222 |
-
context = f"""Answer the following query exclusively based on the information provided in the document above. \
|
223 |
-
If the information is not found, please say so instead of making up facts! Remember to answer the question in the same language as the user query!
|
224 |
-
###
|
225 |
-
{contexts}
|
226 |
-
###
|
227 |
-
|
228 |
-
|
229 |
-
"""
|
230 |
-
return context
|
231 |
-
|
232 |
-
def maybe_get_doc_context(message, file_input, rag_num_docs: Optional[int] = 3):
|
233 |
-
global RAG_CURRENT_FILE, RAG_EMBED, RAG_CURRENT_VECTORSTORE
|
234 |
-
doc_context = None
|
235 |
-
if file_input is not None:
|
236 |
-
assert os.path.exists(file_input), f"not found: {file_input}"
|
237 |
-
if file_input == RAG_CURRENT_FILE:
|
238 |
-
# reuse
|
239 |
-
vectorstore = RAG_CURRENT_VECTORSTORE
|
240 |
-
print(f'Reuse vectorstore: {file_input}')
|
241 |
-
else:
|
242 |
-
vectorstore = load_document_split_vectorstore(file_input)
|
243 |
-
print(f'New vectorstore: {RAG_CURRENT_FILE} {file_input}')
|
244 |
-
RAG_CURRENT_FILE = file_input
|
245 |
-
docs = vectorstore.similarity_search(message, k=rag_num_docs)
|
246 |
-
doc_context = docs_to_rag_context(docs)
|
247 |
-
return doc_context
|
248 |
-
|
249 |
-
# ######### RAG PREPARE
|
250 |
-
|
251 |
-
|
252 |
-
# ============ CONSTANT ============
|
253 |
-
# https://github.com/gradio-app/gradio/issues/884
|
254 |
-
MODEL_NAME = "SeaLLM-7B"
|
255 |
-
MODEL_NAME = str(os.environ.get("MODEL_NAME", "SeaLLM-7B"))
|
256 |
-
|
257 |
-
MODEL_TITLE = """
|
258 |
-
<div class="container" style="
|
259 |
-
align-items: center;
|
260 |
-
justify-content: center;
|
261 |
-
display: flex;
|
262 |
-
">
|
263 |
-
<div class="image" >
|
264 |
-
<img src="file/seal_logo.png" style="
|
265 |
-
max-width: 10em;
|
266 |
-
max-height: 5%;
|
267 |
-
height: 3em;
|
268 |
-
width: 3em;
|
269 |
-
float: left;
|
270 |
-
margin-left: auto;
|
271 |
-
">
|
272 |
-
</div>
|
273 |
-
<div class="text" style="
|
274 |
-
padding-left: 20px;
|
275 |
-
padding-top: 1%;
|
276 |
-
float: left;
|
277 |
-
">
|
278 |
-
<h1 style="font-size: xx-large">SeaLLMs - Large Language Models for Southeast Asia</h1>
|
279 |
-
</div>
|
280 |
-
</div>
|
281 |
-
"""
|
282 |
-
|
283 |
-
MODEL_TITLE = """
|
284 |
-
<img src="file/seal_logo.png" style="
|
285 |
-
max-width: 10em;
|
286 |
-
max-height: 5%;
|
287 |
-
height: 3em;
|
288 |
-
width: 3em;
|
289 |
-
">
|
290 |
-
<div class="text" style="
|
291 |
-
loat: left;
|
292 |
-
padding-bottom: 2%;
|
293 |
-
">
|
294 |
-
SeaLLMs - Large Language Models for Southeast Asia
|
295 |
-
</div>
|
296 |
-
"""
|
297 |
-
|
298 |
-
"""
|
299 |
-
Somehow cannot add image here
|
300 |
-
<div class="image" >
|
301 |
-
<img src="file/seal_logo.png" style="
|
302 |
-
max-width: 10em;
|
303 |
-
max-height: 5%;
|
304 |
-
height: 3em;
|
305 |
-
width: 3em;
|
306 |
-
float: left;
|
307 |
-
margin-left: auto;
|
308 |
-
">
|
309 |
-
</div>
|
310 |
-
"""
|
311 |
-
|
312 |
-
MODEL_DESC = f"""
|
313 |
-
<div style='display:flex; gap: 0.25rem; '>
|
314 |
-
<a href='https://github.com/damo-nlp-sg/seallms'><img src='https://img.shields.io/badge/Github-Code-success'></a>
|
315 |
-
<a href='https://huggingface.co/spaces/SeaLLMs/SeaLLM-7B'><img src='https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Spaces-blue'></a>
|
316 |
-
<a href='https://huggingface.co/SeaLLMs/SeaLLM-7B-v2'><img src='https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Model-blue'></a>
|
317 |
-
<a href='https://arxiv.org/pdf/2312.00738.pdf'><img src='https://img.shields.io/badge/Paper-PDF-red'></a>
|
318 |
-
</div>
|
319 |
-
<span style="font-size: larger">
|
320 |
-
<a href="https://huggingface.co/SeaLLMs/SeaLLM-7B-v2" target="_blank">{MODEL_NAME}-v2</a> - a helpful assistant for Southeast Asian Languages 🇬🇧 🇻🇳 🇮🇩 🇹🇭 🇲🇾 🇰🇭 🇱🇦 🇵🇭 🇲🇲.
|
321 |
-
Explore <a href="https://huggingface.co/SeaLLMs/SeaLLM-7B-v2" target="_blank">our article</a> for more.
|
322 |
-
</span>
|
323 |
-
<br>
|
324 |
-
<span>
|
325 |
-
<span style="color: red">NOTE: The chatbot may produce false and harmful content and does not have up-to-date knowledge.</span>
|
326 |
-
By using our service, you are required to agree to our <a href="https://huggingface.co/SeaLLMs/SeaLLM-Chat-13b/blob/main/LICENSE" target="_blank" style="color: red">Terms Of Use</a>, which includes
|
327 |
-
not to use our service to generate any harmful, inappropriate or illegal content.
|
328 |
-
The service collects user dialogue data for testing and improvement under
|
329 |
-
<a href="https://creativecommons.org/licenses/by/4.0/">(CC-BY)</a> or similar license. So do not enter any personal information!
|
330 |
-
</span>
|
331 |
-
""".strip()
|
332 |
-
|
333 |
-
|
334 |
-
cite_markdown = """
|
335 |
-
## Citation
|
336 |
-
If you find our project useful, hope you can star our repo and cite our paper as follows:
|
337 |
-
```
|
338 |
-
@article{damonlpsg2023seallm,
|
339 |
-
author = {Xuan-Phi Nguyen*, Wenxuan Zhang*, Xin Li*, Mahani Aljunied*, Zhiqiang Hu, Chenhui Shen^, Yew Ken Chia^, Xingxuan Li, Jianyu Wang, Qingyu Tan, Liying Cheng, Guanzheng Chen, Yue Deng, Sen Yang, Chaoqun Liu, Hang Zhang, Lidong Bing},
|
340 |
-
title = {SeaLLMs - Large Language Models for Southeast Asia},
|
341 |
-
year = 2023,
|
342 |
-
}
|
343 |
-
```
|
344 |
-
"""
|
345 |
-
|
346 |
-
path_markdown = """
|
347 |
-
#### Model path:
|
348 |
-
{model_path}
|
349 |
-
"""
|
350 |
-
|
351 |
-
|
352 |
-
|
353 |
-
# ! ==================================================================
|
354 |
-
|
355 |
-
set_documentation_group("component")
|
356 |
-
|
357 |
-
|
358 |
-
RES_PRINTED = False
|
359 |
-
|
360 |
-
|
361 |
-
@document()
|
362 |
-
class ChatBot(gr.Chatbot):
|
363 |
-
def _postprocess_chat_messages(
|
364 |
-
self, chat_message
|
365 |
-
):
|
366 |
-
x = super()._postprocess_chat_messages(chat_message)
|
367 |
-
# if isinstance(x, str):
|
368 |
-
# x = x.strip().replace("\n", "<br>")
|
369 |
-
return x
|
370 |
-
|
371 |
-
|
372 |
-
from gradio.components import Button
|
373 |
from gradio.events import Dependency, EventListenerMethod
|
374 |
|
375 |
-
|
376 |
-
# this prevent weird behavior
|
377 |
-
def _setup_stop_events(
|
378 |
-
self, event_triggers: list[EventListenerMethod], event_to_cancel: Dependency
|
379 |
-
) -> None:
|
380 |
-
from gradio.components import State
|
381 |
-
event_triggers = event_triggers if isinstance(event_triggers, (list, tuple)) else [event_triggers]
|
382 |
-
if self.stop_btn and self.is_generator:
|
383 |
-
if self.submit_btn:
|
384 |
-
for event_trigger in event_triggers:
|
385 |
-
event_trigger(
|
386 |
-
lambda: (
|
387 |
-
Button(visible=False),
|
388 |
-
Button(visible=True),
|
389 |
-
),
|
390 |
-
None,
|
391 |
-
[self.submit_btn, self.stop_btn],
|
392 |
-
api_name=False,
|
393 |
-
queue=False,
|
394 |
-
)
|
395 |
-
event_to_cancel.then(
|
396 |
-
lambda: (Button(visible=True), Button(visible=False)),
|
397 |
-
None,
|
398 |
-
[self.submit_btn, self.stop_btn],
|
399 |
-
api_name=False,
|
400 |
-
queue=False,
|
401 |
-
)
|
402 |
-
else:
|
403 |
-
for event_trigger in event_triggers:
|
404 |
-
event_trigger(
|
405 |
-
lambda: Button(visible=True),
|
406 |
-
None,
|
407 |
-
[self.stop_btn],
|
408 |
-
api_name=False,
|
409 |
-
queue=False,
|
410 |
-
)
|
411 |
-
event_to_cancel.then(
|
412 |
-
lambda: Button(visible=False),
|
413 |
-
None,
|
414 |
-
[self.stop_btn],
|
415 |
-
api_name=False,
|
416 |
-
queue=False,
|
417 |
-
)
|
418 |
-
self.stop_btn.click(
|
419 |
-
None,
|
420 |
-
None,
|
421 |
-
None,
|
422 |
-
cancels=event_to_cancel,
|
423 |
-
api_name=False,
|
424 |
-
)
|
425 |
-
else:
|
426 |
-
if self.submit_btn:
|
427 |
-
for event_trigger in event_triggers:
|
428 |
-
event_trigger(
|
429 |
-
lambda: Button(interactive=False),
|
430 |
-
None,
|
431 |
-
[self.submit_btn],
|
432 |
-
api_name=False,
|
433 |
-
queue=False,
|
434 |
-
)
|
435 |
-
event_to_cancel.then(
|
436 |
-
lambda: Button(interactive=True),
|
437 |
-
None,
|
438 |
-
[self.submit_btn],
|
439 |
-
api_name=False,
|
440 |
-
queue=False,
|
441 |
-
)
|
442 |
-
# upon clear, cancel the submit event as well
|
443 |
-
if self.clear_btn:
|
444 |
-
self.clear_btn.click(
|
445 |
-
lambda: ([], [], None, Button(interactive=True)),
|
446 |
-
None,
|
447 |
-
[self.chatbot, self.chatbot_state, self.saved_input, self.submit_btn],
|
448 |
-
queue=False,
|
449 |
-
api_name=False,
|
450 |
-
cancels=event_to_cancel,
|
451 |
-
)
|
452 |
-
|
453 |
-
# TODO: reconfigure clear button as stop and clear button
|
454 |
-
def _setup_events(self) -> None:
|
455 |
-
from gradio.components import State
|
456 |
-
has_on = False
|
457 |
-
try:
|
458 |
-
from gradio.events import Dependency, EventListenerMethod, on
|
459 |
-
has_on = True
|
460 |
-
except ImportError as ie:
|
461 |
-
has_on = False
|
462 |
-
submit_fn = self._stream_fn if self.is_generator else self._submit_fn
|
463 |
-
|
464 |
-
def update_time(c_time, chatbot_state):
|
465 |
-
# if chatbot_state is empty, register a new conversaion with the current timestamp
|
466 |
-
# assert len(chatbot_state) > 0, f'empty chatbot state'
|
467 |
-
if len(chatbot_state) <= 1:
|
468 |
-
return gr.Number(value=time.time(), label='current_time', visible=False), chatbot_state
|
469 |
-
# elif len(chatbot_state) == 1:
|
470 |
-
# # assert chatbot_state[-1][-1] is None, f'invalid [[message, None]] , got {chatbot_state}'
|
471 |
-
# return gr.Number(value=time.time(), label='current_time', visible=False), chatbot_state
|
472 |
-
else:
|
473 |
-
return c_time, chatbot_state
|
474 |
-
|
475 |
-
if has_on:
|
476 |
-
# new version
|
477 |
-
submit_triggers = (
|
478 |
-
[self.textbox.submit, self.submit_btn.click]
|
479 |
-
if self.submit_btn
|
480 |
-
else [self.textbox.submit]
|
481 |
-
)
|
482 |
-
submit_event = (
|
483 |
-
on(
|
484 |
-
submit_triggers,
|
485 |
-
self._clear_and_save_textbox,
|
486 |
-
[self.textbox],
|
487 |
-
[self.textbox, self.saved_input],
|
488 |
-
api_name=False,
|
489 |
-
queue=False,
|
490 |
-
)
|
491 |
-
.then(
|
492 |
-
self._display_input,
|
493 |
-
[self.saved_input, self.chatbot_state],
|
494 |
-
[self.chatbot, self.chatbot_state],
|
495 |
-
api_name=False,
|
496 |
-
queue=False,
|
497 |
-
)
|
498 |
-
.then(
|
499 |
-
update_time,
|
500 |
-
[self.additional_inputs[-1], self.chatbot_state],
|
501 |
-
[self.additional_inputs[-1], self.chatbot_state],
|
502 |
-
api_name=False,
|
503 |
-
queue=False,
|
504 |
-
)
|
505 |
-
.then(
|
506 |
-
submit_fn,
|
507 |
-
[self.saved_input, self.chatbot_state] + self.additional_inputs,
|
508 |
-
[self.chatbot, self.chatbot_state],
|
509 |
-
api_name=False,
|
510 |
-
)
|
511 |
-
)
|
512 |
-
self._setup_stop_events(submit_triggers, submit_event)
|
513 |
-
else:
|
514 |
-
raise ValueError(f'Better install new gradio version than 3.44.0')
|
515 |
-
|
516 |
-
if self.retry_btn:
|
517 |
-
retry_event = (
|
518 |
-
self.retry_btn.click(
|
519 |
-
self._delete_prev_fn,
|
520 |
-
[self.chatbot_state],
|
521 |
-
[self.chatbot, self.saved_input, self.chatbot_state],
|
522 |
-
api_name=False,
|
523 |
-
queue=False,
|
524 |
-
)
|
525 |
-
.then(
|
526 |
-
self._display_input,
|
527 |
-
[self.saved_input, self.chatbot_state],
|
528 |
-
[self.chatbot, self.chatbot_state],
|
529 |
-
api_name=False,
|
530 |
-
queue=False,
|
531 |
-
)
|
532 |
-
.then(
|
533 |
-
submit_fn,
|
534 |
-
[self.saved_input, self.chatbot_state] + self.additional_inputs,
|
535 |
-
[self.chatbot, self.chatbot_state],
|
536 |
-
api_name=False,
|
537 |
-
)
|
538 |
-
)
|
539 |
-
self._setup_stop_events([self.retry_btn.click], retry_event)
|
540 |
-
|
541 |
-
if self.undo_btn:
|
542 |
-
self.undo_btn.click(
|
543 |
-
self._delete_prev_fn,
|
544 |
-
[self.chatbot_state],
|
545 |
-
[self.chatbot, self.saved_input, self.chatbot_state],
|
546 |
-
api_name=False,
|
547 |
-
queue=False,
|
548 |
-
).then(
|
549 |
-
lambda x: x,
|
550 |
-
[self.saved_input],
|
551 |
-
[self.textbox],
|
552 |
-
api_name=False,
|
553 |
-
queue=False,
|
554 |
-
)
|
555 |
-
|
556 |
-
# Reconfigure clear_btn to stop and clear text box
|
557 |
-
|
558 |
-
|
559 |
-
def _display_input(
|
560 |
-
self, message: str, history: List[List[Union[str, None]]]
|
561 |
-
) -> Tuple[List[List[Union[str, None]]], List[List[list[Union[str, None]]]]]:
|
562 |
-
if message is not None and message.strip() != "":
|
563 |
-
history.append([message, None])
|
564 |
-
return history, history
|
565 |
-
|
566 |
-
|
567 |
-
async def _stream_fn(
|
568 |
-
self,
|
569 |
-
message: str,
|
570 |
-
history_with_input,
|
571 |
-
request: Request,
|
572 |
-
*args,
|
573 |
-
) -> AsyncGenerator:
|
574 |
-
history = history_with_input[:-1]
|
575 |
-
inputs, _, _ = special_args(
|
576 |
-
self.fn, inputs=[message, history, *args], request=request
|
577 |
-
)
|
578 |
-
|
579 |
-
if self.is_async:
|
580 |
-
generator = self.fn(*inputs)
|
581 |
-
else:
|
582 |
-
generator = await anyio.to_thread.run_sync(
|
583 |
-
self.fn, *inputs, limiter=self.limiter
|
584 |
-
)
|
585 |
-
generator = SyncToAsyncIterator(generator, self.limiter)
|
586 |
-
try:
|
587 |
-
first_response = await async_iteration(generator)
|
588 |
-
update = history + [[message, first_response]]
|
589 |
-
yield update, update
|
590 |
-
except StopIteration:
|
591 |
-
update = history + [[message, None]]
|
592 |
-
yield update, update
|
593 |
-
except Exception as e:
|
594 |
-
yield history, history
|
595 |
-
raise e
|
596 |
-
|
597 |
-
try:
|
598 |
-
async for response in generator:
|
599 |
-
update = history + [[message, response]]
|
600 |
-
yield update, update
|
601 |
-
except Exception as e:
|
602 |
-
# if "invalid" in str(e):
|
603 |
-
# yield history, history
|
604 |
-
# raise e
|
605 |
-
# else:
|
606 |
-
# raise e
|
607 |
-
yield history, history
|
608 |
-
raise e
|
609 |
-
|
610 |
-
|
611 |
-
|
612 |
-
|
613 |
-
# replace
|
614 |
-
gr.ChatInterface._setup_stop_events = _setup_stop_events
|
615 |
-
gr.ChatInterface._setup_events = _setup_events
|
616 |
-
gr.ChatInterface._display_input = _display_input
|
617 |
-
gr.ChatInterface._stream_fn = _stream_fn
|
618 |
-
|
619 |
-
|
620 |
-
@document()
|
621 |
-
class CustomTabbedInterface(gr.Blocks):
|
622 |
-
def __init__(
|
623 |
-
self,
|
624 |
-
interface_list: list[gr.Interface],
|
625 |
-
tab_names: Optional[list[str]] = None,
|
626 |
-
title: Optional[str] = None,
|
627 |
-
description: Optional[str] = None,
|
628 |
-
theme: Optional[gr.Theme] = None,
|
629 |
-
analytics_enabled: Optional[bool] = None,
|
630 |
-
css: Optional[str] = None,
|
631 |
-
):
|
632 |
-
"""
|
633 |
-
Parameters:
|
634 |
-
interface_list: a list of interfaces to be rendered in tabs.
|
635 |
-
tab_names: a list of tab names. If None, the tab names will be "Tab 1", "Tab 2", etc.
|
636 |
-
title: a title for the interface; if provided, appears above the input and output components in large font. Also used as the tab title when opened in a browser window.
|
637 |
-
analytics_enabled: whether to allow basic telemetry. If None, will use GRADIO_ANALYTICS_ENABLED environment variable or default to True.
|
638 |
-
css: custom css or path to custom css file to apply to entire Blocks
|
639 |
-
Returns:
|
640 |
-
a Gradio Tabbed Interface for the given interfaces
|
641 |
-
"""
|
642 |
-
super().__init__(
|
643 |
-
title=title or "Gradio",
|
644 |
-
theme=theme,
|
645 |
-
analytics_enabled=analytics_enabled,
|
646 |
-
mode="tabbed_interface",
|
647 |
-
css=css,
|
648 |
-
)
|
649 |
-
self.description = description
|
650 |
-
if tab_names is None:
|
651 |
-
tab_names = [f"Tab {i}" for i in range(len(interface_list))]
|
652 |
-
with self:
|
653 |
-
if title:
|
654 |
-
gr.Markdown(
|
655 |
-
f"<h1 style='text-align: center; margin-bottom: 1rem'>{title}</h1>"
|
656 |
-
)
|
657 |
-
if description:
|
658 |
-
gr.Markdown(description)
|
659 |
-
with gr.Tabs():
|
660 |
-
for interface, tab_name in zip(interface_list, tab_names):
|
661 |
-
with gr.Tab(label=tab_name):
|
662 |
-
interface.render()
|
663 |
-
|
664 |
-
|
665 |
-
def vllm_abort(self):
|
666 |
-
sh = self.llm_engine.scheduler
|
667 |
-
for g in (sh.waiting + sh.running + sh.swapped):
|
668 |
-
sh.abort_seq_group(g.request_id)
|
669 |
-
from vllm.sequence import SequenceStatus
|
670 |
-
scheduler = self.llm_engine.scheduler
|
671 |
-
for state_queue in [scheduler.waiting, scheduler.running, scheduler.swapped]:
|
672 |
-
for seq_group in state_queue:
|
673 |
-
# if seq_group.request_id == request_id:
|
674 |
-
# Remove the sequence group from the state queue.
|
675 |
-
state_queue.remove(seq_group)
|
676 |
-
for seq in seq_group.seqs:
|
677 |
-
if seq.is_finished():
|
678 |
-
continue
|
679 |
-
scheduler.free_seq(seq, SequenceStatus.FINISHED_ABORTED)
|
680 |
-
|
681 |
-
|
682 |
-
def _vllm_run_engine(self: Any, use_tqdm: bool = False) -> Dict[str, Any]:
|
683 |
-
from vllm.outputs import RequestOutput
|
684 |
-
# Initialize tqdm.
|
685 |
-
if use_tqdm:
|
686 |
-
num_requests = self.llm_engine.get_num_unfinished_requests()
|
687 |
-
pbar = tqdm(total=num_requests, desc="Processed prompts")
|
688 |
-
# Run the engine.
|
689 |
-
outputs: Dict[str, RequestOutput] = {}
|
690 |
-
while self.llm_engine.has_unfinished_requests():
|
691 |
-
step_outputs = self.llm_engine.step()
|
692 |
-
for output in step_outputs:
|
693 |
-
outputs[output.request_id] = output
|
694 |
-
if len(outputs) > 0:
|
695 |
-
yield outputs
|
696 |
-
|
697 |
-
|
698 |
-
|
699 |
-
def vllm_generate_stream(
|
700 |
-
self: Any,
|
701 |
-
prompts: Optional[Union[str, List[str]]] = None,
|
702 |
-
sampling_params: Optional[Any] = None,
|
703 |
-
prompt_token_ids: Optional[List[List[int]]] = None,
|
704 |
-
use_tqdm: bool = False,
|
705 |
-
) -> Dict[str, Any]:
|
706 |
-
"""Generates the completions for the input prompts.
|
707 |
-
|
708 |
-
NOTE: This class automatically batches the given prompts, considering
|
709 |
-
the memory constraint. For the best performance, put all of your prompts
|
710 |
-
into a single list and pass it to this method.
|
711 |
-
|
712 |
-
Args:
|
713 |
-
prompts: A list of prompts to generate completions for.
|
714 |
-
sampling_params: The sampling parameters for text generation. If
|
715 |
-
None, we use the default sampling parameters.
|
716 |
-
prompt_token_ids: A list of token IDs for the prompts. If None, we
|
717 |
-
use the tokenizer to convert the prompts to token IDs.
|
718 |
-
use_tqdm: Whether to use tqdm to display the progress bar.
|
719 |
-
|
720 |
-
Returns:
|
721 |
-
A list of `RequestOutput` objects containing the generated
|
722 |
-
completions in the same order as the input prompts.
|
723 |
-
"""
|
724 |
-
from vllm import LLM, SamplingParams
|
725 |
-
if prompts is None and prompt_token_ids is None:
|
726 |
-
raise ValueError("Either prompts or prompt_token_ids must be "
|
727 |
-
"provided.")
|
728 |
-
if isinstance(prompts, str):
|
729 |
-
# Convert a single prompt to a list.
|
730 |
-
prompts = [prompts]
|
731 |
-
if prompts is not None and prompt_token_ids is not None:
|
732 |
-
if len(prompts) != len(prompt_token_ids):
|
733 |
-
raise ValueError("The lengths of prompts and prompt_token_ids "
|
734 |
-
"must be the same.")
|
735 |
-
if sampling_params is None:
|
736 |
-
# Use default sampling params.
|
737 |
-
sampling_params = SamplingParams()
|
738 |
-
|
739 |
-
# Add requests to the engine.
|
740 |
-
if prompts is not None:
|
741 |
-
num_requests = len(prompts)
|
742 |
-
else:
|
743 |
-
num_requests = len(prompt_token_ids)
|
744 |
-
for i in range(num_requests):
|
745 |
-
prompt = prompts[i] if prompts is not None else None
|
746 |
-
if prompt_token_ids is None:
|
747 |
-
token_ids = None
|
748 |
-
else:
|
749 |
-
token_ids = prompt_token_ids[i]
|
750 |
-
self._add_request(prompt, sampling_params, token_ids)
|
751 |
-
# return self._run_engine(use_tqdm)
|
752 |
-
yield from _vllm_run_engine(self, use_tqdm)
|
753 |
-
|
754 |
-
|
755 |
-
|
756 |
-
# ! avoid saying
|
757 |
-
# LANG_BLOCK_MESSAGE = """Sorry, the language you have asked is currently not supported. If you have questions in other supported languages, I'll be glad to help. \
|
758 |
-
# Please also consider clearing the chat box for a better experience."""
|
759 |
-
|
760 |
-
# KEYWORD_BLOCK_MESSAGE = "Sorry, I cannot fulfill your request. If you have any unrelated question, I'll be glad to help."
|
761 |
-
|
762 |
-
LANG_BLOCK_MESSAGE = """Unsupported language."""
|
763 |
-
|
764 |
-
KEYWORD_BLOCK_MESSAGE = "Invalid request."
|
765 |
-
|
766 |
-
|
767 |
-
def _detect_lang(text):
|
768 |
-
# Disable language that may have safety risk
|
769 |
-
from langdetect import detect as detect_lang
|
770 |
-
dlang = None
|
771 |
-
try:
|
772 |
-
dlang = detect_lang(text)
|
773 |
-
except Exception as e:
|
774 |
-
if "No features in text." in str(e):
|
775 |
-
return "en"
|
776 |
-
else:
|
777 |
-
return "zh"
|
778 |
-
return dlang
|
779 |
-
|
780 |
-
|
781 |
-
def block_lang(
|
782 |
-
message: str,
|
783 |
-
history: List[Tuple[str, str]] = None,
|
784 |
-
) -> str:
|
785 |
-
# relieve history base block
|
786 |
-
if len(BLOCK_LANGS) == 0:
|
787 |
-
return False
|
788 |
-
|
789 |
-
if LANG_BLOCK_HISTORY and history is not None and any((LANG_BLOCK_MESSAGE in x[1].strip()) for x in history):
|
790 |
-
return True
|
791 |
-
else:
|
792 |
-
_lang = _detect_lang(message)
|
793 |
-
if _lang in BLOCK_LANGS:
|
794 |
-
print(f'Detect blocked {_lang}: {message}')
|
795 |
-
return True
|
796 |
-
else:
|
797 |
-
return False
|
798 |
-
|
799 |
-
|
800 |
-
def safety_check(text, history=None, ) -> Optional[str]:
|
801 |
-
"""
|
802 |
-
Despite our effort in safety tuning and red teaming, our models may still generate harmful or illegal content.
|
803 |
-
This provides an additional security measure to enhance safety and compliance with local regulations.
|
804 |
-
"""
|
805 |
-
if len(KEYWORDS) > 0 and any(x in text.lower() for x in KEYWORDS):
|
806 |
-
return KEYWORD_BLOCK_MESSAGE
|
807 |
-
|
808 |
-
if len(BLOCK_LANGS) > 0:
|
809 |
-
if block_lang(text, history):
|
810 |
-
return LANG_BLOCK_MESSAGE
|
811 |
-
|
812 |
-
return None
|
813 |
-
|
814 |
-
|
815 |
-
|
816 |
-
TURN_TEMPLATE = "<|im_start|>{role}\n{content}</s>"
|
817 |
-
TURN_PREFIX = "<|im_start|>{role}\n"
|
818 |
-
|
819 |
-
|
820 |
-
def chatml_chat_convo_format(conversations, add_assistant_prefix: bool, default_system=SYSTEM_PROMPT_1):
|
821 |
-
if conversations[0]['role'] != 'system':
|
822 |
-
conversations = [{"role": "system", "content": default_system}] + conversations
|
823 |
-
text = ''
|
824 |
-
for turn_id, turn in enumerate(conversations):
|
825 |
-
prompt = TURN_TEMPLATE.format(role=turn['role'], content=turn['content'])
|
826 |
-
text += prompt
|
827 |
-
if add_assistant_prefix:
|
828 |
-
prompt = TURN_PREFIX.format(role='assistant')
|
829 |
-
text += prompt
|
830 |
-
return text
|
831 |
-
|
832 |
-
|
833 |
-
def chatml_format(message, history=None, system_prompt=None):
|
834 |
-
conversations = []
|
835 |
-
system_prompt = system_prompt or "You are a helpful assistant."
|
836 |
-
if history is not None and len(history) > 0:
|
837 |
-
for i, (prompt, res) in enumerate(history):
|
838 |
-
conversations.append({"role": "user", "content": prompt.strip()})
|
839 |
-
conversations.append({"role": "assistant", "content": res.strip()})
|
840 |
-
conversations.append({"role": "user", "content": message.strip()})
|
841 |
-
return chatml_chat_convo_format(conversations, True, default_system=system_prompt)
|
842 |
-
|
843 |
-
|
844 |
-
def debug_chat_response_stream_multiturn(message, history):
|
845 |
-
message_safety = safety_check(message, history=history)
|
846 |
-
if message_safety is not None:
|
847 |
-
# yield message_safety
|
848 |
-
raise gr.Error(message_safety)
|
849 |
-
|
850 |
-
message = "This is a debugging message"
|
851 |
-
for i in range(len(message)):
|
852 |
-
time.sleep(0.05)
|
853 |
-
yield message[:i]
|
854 |
-
|
855 |
-
|
856 |
-
|
857 |
-
def chat_response_stream_multiturn(
|
858 |
-
message: str,
|
859 |
-
history: List[Tuple[str, str]],
|
860 |
-
temperature: float,
|
861 |
-
max_tokens: int,
|
862 |
-
frequency_penalty: float,
|
863 |
-
presence_penalty: float,
|
864 |
-
system_prompt: Optional[str] = SYSTEM_PROMPT_1,
|
865 |
-
current_time: Optional[float] = None,
|
866 |
-
# profile: Optional[gr.OAuthProfile] = None,
|
867 |
-
) -> str:
|
868 |
-
"""
|
869 |
-
gr.Number(value=temperature, label='Temperature (higher -> more random)'),
|
870 |
-
gr.Number(value=max_tokens, label='Max generated tokens (increase if want more generation)'),
|
871 |
-
gr.Number(value=frequence_penalty, label='Frequency penalty (> 0 encourage new tokens over repeated tokens)'),
|
872 |
-
gr.Number(value=presence_penalty, label='Presence penalty (> 0 encourage new tokens, < 0 encourage existing tokens)'),
|
873 |
-
gr.Textbox(value=sys_prompt, label='System prompt', lines=8, interactive=False),
|
874 |
-
gr.Number(value=0, label='current_time', visible=False),
|
875 |
-
"""
|
876 |
-
global LOG_FILE, LOG_PATH
|
877 |
-
if DEBUG:
|
878 |
-
yield from debug_chat_response_stream_multiturn(message, history)
|
879 |
-
return
|
880 |
-
from vllm import LLM, SamplingParams
|
881 |
-
"""Build multi turn
|
882 |
-
|
883 |
-
message is incoming prompt
|
884 |
-
history don't have the current messauge
|
885 |
-
"""
|
886 |
-
global llm, RES_PRINTED
|
887 |
-
assert llm is not None
|
888 |
-
assert system_prompt.strip() != '', f'system prompt is empty'
|
889 |
-
# is_by_pass = False if profile is None else profile.username in BYPASS_USERS
|
890 |
-
is_by_pass = False
|
891 |
-
|
892 |
-
tokenizer = llm.get_tokenizer()
|
893 |
-
# force removing all
|
894 |
-
vllm_abort(llm)
|
895 |
-
|
896 |
-
temperature = float(temperature)
|
897 |
-
frequency_penalty = float(frequency_penalty)
|
898 |
-
max_tokens = int(max_tokens)
|
899 |
-
|
900 |
-
message = message.strip()
|
901 |
|
902 |
-
|
903 |
-
|
904 |
-
|
905 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
906 |
|
907 |
-
if len(message) == 0:
|
908 |
-
raise gr.Error("The message cannot be empty!")
|
909 |
|
910 |
-
|
911 |
-
if message_safety is not None and not is_by_pass:
|
912 |
-
# yield message_safety
|
913 |
-
raise gr.Error(message_safety)
|
914 |
-
|
915 |
-
# history will be appended with message later on
|
916 |
-
|
917 |
-
full_prompt = chatml_format(message.strip(), history=history, system_prompt=system_prompt)
|
918 |
-
print(full_prompt)
|
919 |
-
|
920 |
-
if len(tokenizer.encode(full_prompt)) >= 4050:
|
921 |
-
raise gr.Error(f"Conversation or prompt is too long, please clear the chatbox or try shorter input.")
|
922 |
-
|
923 |
-
sampling_params = SamplingParams(
|
924 |
-
temperature=temperature,
|
925 |
-
max_tokens=max_tokens,
|
926 |
-
frequency_penalty=frequency_penalty,
|
927 |
-
presence_penalty=presence_penalty,
|
928 |
-
# stop=['<s>', '</s>', '<<SYS>>', '<</SYS>>', '[INST]', '[/INST]'],
|
929 |
-
stop=['<s>', '</s>', '<|im_start|>', '<|im_end|>'],
|
930 |
-
)
|
931 |
-
cur_out = None
|
932 |
-
|
933 |
-
for j, gen in enumerate(vllm_generate_stream(llm, full_prompt, sampling_params)):
|
934 |
-
if cur_out is not None and (STREAM_YIELD_MULTIPLE < 1 or j % STREAM_YIELD_MULTIPLE == 0) and j > 0:
|
935 |
-
# cur_out = cur_out.replace("\\n", "\n")
|
936 |
-
|
937 |
-
# optionally check safety, and respond
|
938 |
-
if STREAM_CHECK_MULTIPLE > 0 and j % STREAM_CHECK_MULTIPLE == 0:
|
939 |
-
message_safety = safety_check(cur_out, history=None)
|
940 |
-
if message_safety is not None and not is_by_pass:
|
941 |
-
# yield message_safety
|
942 |
-
raise gr.Error(message_safety)
|
943 |
-
# return
|
944 |
-
|
945 |
-
yield cur_out
|
946 |
-
assert len(gen) == 1, f'{gen}'
|
947 |
-
item = next(iter(gen.values()))
|
948 |
-
cur_out = item.outputs[0].text
|
949 |
-
#cur_out = "Our system is under maintenance, will be back soon!"
|
950 |
-
if j >= max_tokens - 2:
|
951 |
-
gr.Warning(f'The response hits limit of {max_tokens} tokens. Consider increase the max tokens parameter in the Additional Inputs.')
|
952 |
-
|
953 |
-
# TODO: use current_time to register conversations, accoriding history and cur_out
|
954 |
-
history_str = format_conversation(history + [[message, cur_out]])
|
955 |
-
print(f'@@@@@@@@@@\n{history_str}\n##########\n')
|
956 |
-
|
957 |
-
maybe_log_conv_file(current_time, history, message, cur_out, temperature=temperature, frequency_penalty=frequency_penalty)
|
958 |
-
|
959 |
-
if cur_out is not None and "\\n" in cur_out:
|
960 |
-
print(f'double slash-n in cur_out:\n{cur_out}')
|
961 |
-
cur_out = cur_out.replace("\\n", "\n")
|
962 |
-
|
963 |
-
if cur_out is not None:
|
964 |
-
yield cur_out
|
965 |
-
|
966 |
-
message_safety = safety_check(cur_out, history=None)
|
967 |
-
if message_safety is not None and not is_by_pass:
|
968 |
-
# yield message_safety
|
969 |
-
raise gr.Error(message_safety)
|
970 |
-
# return
|
971 |
-
|
972 |
-
|
973 |
-
|
974 |
-
def chat_response_stream_rag_multiturn(
|
975 |
-
message: str,
|
976 |
-
history: List[Tuple[str, str]],
|
977 |
-
file_input: str,
|
978 |
-
temperature: float,
|
979 |
-
max_tokens: int,
|
980 |
-
# frequency_penalty: float,
|
981 |
-
# presence_penalty: float,
|
982 |
-
system_prompt: Optional[str] = SYSTEM_PROMPT_1,
|
983 |
-
current_time: Optional[float] = None,
|
984 |
-
rag_num_docs: Optional[int] = 3,
|
985 |
-
):
|
986 |
-
message = message.strip()
|
987 |
-
frequency_penalty = FREQUENCE_PENALTY
|
988 |
-
presence_penalty = PRESENCE_PENALTY
|
989 |
-
if len(message) == 0:
|
990 |
-
raise gr.Error("The message cannot be empty!")
|
991 |
-
doc_context = maybe_get_doc_context(message, file_input, rag_num_docs=rag_num_docs)
|
992 |
-
if doc_context is not None:
|
993 |
-
message = f"{doc_context}\n\n{message}"
|
994 |
-
yield from chat_response_stream_multiturn(
|
995 |
-
message, history, temperature, max_tokens, frequency_penalty,
|
996 |
-
presence_penalty, system_prompt, current_time
|
997 |
-
)
|
998 |
-
|
999 |
-
|
1000 |
-
def debug_generate_free_form_stream(message):
|
1001 |
-
output = " This is a debugging message...."
|
1002 |
-
for i in range(len(output)):
|
1003 |
-
time.sleep(0.05)
|
1004 |
-
yield message + output[:i]
|
1005 |
-
|
1006 |
-
|
1007 |
-
def generate_free_form_stream(
|
1008 |
-
message: str,
|
1009 |
-
temperature: float,
|
1010 |
-
max_tokens: int,
|
1011 |
-
frequency_penalty: float,
|
1012 |
-
presence_penalty: float,
|
1013 |
-
stop_strings: str = '<s>,</s>,<|im_start|>,<|im_end|>',
|
1014 |
-
current_time: Optional[float] = None,
|
1015 |
-
) -> str:
|
1016 |
-
global LOG_FILE, LOG_PATH
|
1017 |
-
if DEBUG:
|
1018 |
-
yield from debug_generate_free_form_stream(message)
|
1019 |
-
return
|
1020 |
-
from vllm import LLM, SamplingParams
|
1021 |
-
"""Build multi turn
|
1022 |
-
"""
|
1023 |
-
global llm, RES_PRINTED
|
1024 |
-
assert llm is not None
|
1025 |
-
tokenizer = llm.get_tokenizer()
|
1026 |
-
# force removing all
|
1027 |
-
vllm_abort(llm)
|
1028 |
-
|
1029 |
-
temperature = float(temperature)
|
1030 |
-
frequency_penalty = float(frequency_penalty)
|
1031 |
-
max_tokens = int(max_tokens)
|
1032 |
-
|
1033 |
-
stop_strings = [x.strip() for x in stop_strings.strip().split(",")]
|
1034 |
-
stop_strings = list(set(stop_strings + ['</s>', '<|im_start|>']))
|
1035 |
-
|
1036 |
-
sampling_params = SamplingParams(
|
1037 |
-
temperature=temperature,
|
1038 |
-
max_tokens=max_tokens,
|
1039 |
-
frequency_penalty=frequency_penalty,
|
1040 |
-
presence_penalty=presence_penalty,
|
1041 |
-
stop=stop_strings,
|
1042 |
-
# ignore_eos=True,
|
1043 |
-
)
|
1044 |
-
|
1045 |
-
# full_prompt = message
|
1046 |
-
if len(message) == 0:
|
1047 |
-
raise gr.Error("The message cannot be empty!")
|
1048 |
-
|
1049 |
-
message_safety = safety_check(message)
|
1050 |
-
if message_safety is not None:
|
1051 |
-
raise gr.Error(message_safety)
|
1052 |
-
|
1053 |
-
if len(tokenizer.encode(message)) >= 4050:
|
1054 |
-
raise gr.Error(f"Prompt is too long!")
|
1055 |
-
|
1056 |
-
cur_out = None
|
1057 |
-
for j, gen in enumerate(vllm_generate_stream(llm, message, sampling_params)):
|
1058 |
-
if cur_out is not None and (STREAM_YIELD_MULTIPLE < 1 or j % STREAM_YIELD_MULTIPLE == 0) and j > 0:
|
1059 |
-
# optionally check safety, and respond
|
1060 |
-
if STREAM_CHECK_MULTIPLE > 0 and j % STREAM_CHECK_MULTIPLE == 0:
|
1061 |
-
message_safety = safety_check(cur_out, history=None)
|
1062 |
-
if message_safety is not None:
|
1063 |
-
raise gr.Error(message_safety)
|
1064 |
-
yield message + cur_out
|
1065 |
-
assert len(gen) == 1, f'{gen}'
|
1066 |
-
item = next(iter(gen.values()))
|
1067 |
-
cur_out = item.outputs[0].text
|
1068 |
-
#cur_out = "Our system is under maintenance, will be back soon!"
|
1069 |
-
if j >= max_tokens - 2:
|
1070 |
-
gr.Warning(f'The response hits limit of {max_tokens} tokens. Consider increase the max tokens parameter in the Additional Inputs.')
|
1071 |
-
|
1072 |
-
if cur_out is not None:
|
1073 |
-
yield message + cur_out
|
1074 |
-
|
1075 |
-
message_safety = safety_check(message + cur_out, history=None)
|
1076 |
-
if message_safety is not None:
|
1077 |
-
raise gr.Error(message_safety)
|
1078 |
-
|
1079 |
-
|
1080 |
-
|
1081 |
-
|
1082 |
-
def maybe_log_conv_file(current_time, history, message, response, **kwargs):
|
1083 |
-
global LOG_FILE
|
1084 |
-
if LOG_FILE is not None:
|
1085 |
-
my_history = history + [[message, response]]
|
1086 |
-
obj = {
|
1087 |
-
'key': str(current_time),
|
1088 |
-
'history': my_history
|
1089 |
-
}
|
1090 |
-
for k, v in kwargs.items():
|
1091 |
-
obj[k] = v
|
1092 |
-
log_ = json.dumps(obj, ensure_ascii=False)
|
1093 |
-
LOG_FILE.write(log_ + "\n")
|
1094 |
-
LOG_FILE.flush()
|
1095 |
-
print(f'Wrote {obj["key"]} to {LOG_PATH}')
|
1096 |
-
|
1097 |
-
|
1098 |
-
def format_conversation(history):
|
1099 |
-
_str = '\n'.join([
|
1100 |
-
(
|
1101 |
-
f'<<<User>>> {h[0]}\n'
|
1102 |
-
f'<<<Asst>>> {h[1]}'
|
1103 |
-
)
|
1104 |
-
for h in history
|
1105 |
-
])
|
1106 |
-
return _str
|
1107 |
-
|
1108 |
-
|
1109 |
-
def aggregate_convos():
|
1110 |
-
from datetime import datetime
|
1111 |
-
global LOG_FILE, DATA_SET_REPO_PATH, SAVE_LOGS
|
1112 |
-
assert os.path.exists(LOG_PATH), f'{LOG_PATH} not found'
|
1113 |
-
convos = None
|
1114 |
-
irregular_count = 1
|
1115 |
-
with open(LOG_PATH, 'r', encoding='utf-8') as f:
|
1116 |
-
convos = {}
|
1117 |
-
for i, l in enumerate(f):
|
1118 |
-
if l:
|
1119 |
-
item = json.loads(l)
|
1120 |
-
key = item['key']
|
1121 |
-
try:
|
1122 |
-
key = float(key)
|
1123 |
-
except Exception as e:
|
1124 |
-
key = -1
|
1125 |
-
if key > 0.0:
|
1126 |
-
item_key = datetime.fromtimestamp(key).strftime("%Y-%m-%d %H:%M:%S")
|
1127 |
-
else:
|
1128 |
-
key = item_key = f'e{irregular_count}'
|
1129 |
-
irregular_count += 1
|
1130 |
-
item['key'] = item_key
|
1131 |
-
convos[key] = item
|
1132 |
-
return convos
|
1133 |
-
|
1134 |
-
def maybe_upload_to_dataset():
|
1135 |
-
from datetime import datetime
|
1136 |
-
global LOG_FILE, DATA_SET_REPO_PATH, SAVE_LOGS
|
1137 |
-
if SAVE_LOGS and os.path.exists(LOG_PATH) and DATA_SET_REPO_PATH != "":
|
1138 |
-
convos = aggregate_convos()
|
1139 |
-
AGG_LOG_PATH = LOG_PATH + ".agg.json"
|
1140 |
-
with open(AGG_LOG_PATH, 'w', encoding='utf-8') as fo:
|
1141 |
-
json.dump(convos, fo, indent=4, ensure_ascii=False)
|
1142 |
-
print(f'Saved aggregated json to {AGG_LOG_PATH}')
|
1143 |
-
try:
|
1144 |
-
from huggingface_hub import upload_file
|
1145 |
-
print(f'upload {AGG_LOG_PATH} to {DATA_SET_REPO_PATH}')
|
1146 |
-
upload_file(
|
1147 |
-
path_or_fileobj=AGG_LOG_PATH,
|
1148 |
-
path_in_repo=os.path.basename(AGG_LOG_PATH),
|
1149 |
-
repo_id=DATA_SET_REPO_PATH,
|
1150 |
-
token=HF_TOKEN,
|
1151 |
-
repo_type="dataset",
|
1152 |
-
create_pr=True
|
1153 |
-
)
|
1154 |
-
except Exception as e:
|
1155 |
-
print(f'Failed to save to repo: {DATA_SET_REPO_PATH}|{str(e)}')
|
1156 |
-
|
1157 |
-
|
1158 |
-
def print_log_file():
|
1159 |
-
global LOG_FILE, LOG_PATH
|
1160 |
-
if SAVE_LOGS and os.path.exists(LOG_PATH):
|
1161 |
-
with open(LOG_PATH, 'r', encoding='utf-8') as f:
|
1162 |
-
convos = aggregate_convos()
|
1163 |
-
print(f'Printing log from {LOG_PATH}')
|
1164 |
-
items = list(convos.items())
|
1165 |
-
for k, v in items[-10:]:
|
1166 |
-
history = v.pop('history')
|
1167 |
-
print(f'######--{v}--#####')
|
1168 |
-
_str = format_conversation(history)
|
1169 |
-
print(_str)
|
1170 |
-
maybe_upload_to_dataset()
|
1171 |
-
|
1172 |
-
|
1173 |
-
def debug_chat_response_echo(
|
1174 |
-
message: str,
|
1175 |
-
history: List[Tuple[str, str]],
|
1176 |
-
temperature: float = 0.0,
|
1177 |
-
max_tokens: int = 4096,
|
1178 |
-
frequency_penalty: float = 0.4,
|
1179 |
-
presence_penalty: float = 0.0,
|
1180 |
-
current_time: Optional[float] = None,
|
1181 |
-
system_prompt: str = SYSTEM_PROMPT_1,
|
1182 |
-
) -> str:
|
1183 |
-
global LOG_FILE
|
1184 |
-
import time
|
1185 |
-
time.sleep(0.5)
|
1186 |
-
|
1187 |
-
if message.strip() == GET_LOG_CMD:
|
1188 |
-
print_log_file()
|
1189 |
-
yield "Finish printed log."
|
1190 |
-
return
|
1191 |
-
|
1192 |
-
for i in range(len(message)):
|
1193 |
-
yield f"repeat: {current_time} {message[:i + 1]}"
|
1194 |
-
|
1195 |
-
cur_out = f"repeat: {current_time} {message}"
|
1196 |
-
maybe_log_conv_file(current_time, history, message, cur_out, temperature=temperature, frequency_penalty=frequency_penalty)
|
1197 |
-
|
1198 |
-
|
1199 |
-
def check_model_path(model_path) -> str:
|
1200 |
-
assert os.path.exists(model_path), f'{model_path} not found'
|
1201 |
-
ckpt_info = "None"
|
1202 |
-
if os.path.isdir(model_path):
|
1203 |
-
if os.path.exists(f'{model_path}/info.txt'):
|
1204 |
-
with open(f'{model_path}/info.txt', 'r') as f:
|
1205 |
-
ckpt_info = f.read()
|
1206 |
-
print(f'Checkpoint info:\n{ckpt_info}\n-----')
|
1207 |
-
else:
|
1208 |
-
print(f'info.txt not found in {model_path}')
|
1209 |
-
print(f'model path dir: {list(os.listdir(model_path))}')
|
1210 |
-
|
1211 |
-
return ckpt_info
|
1212 |
-
|
1213 |
-
|
1214 |
-
def maybe_delete_folder():
|
1215 |
-
if IS_DELETE_FOLDER and DOWNLOAD_SNAPSHOT:
|
1216 |
-
import shutil
|
1217 |
-
print(f'DELETE ALL FILES IN {DELETE_FOLDER}')
|
1218 |
-
for filename in os.listdir(DELETE_FOLDER):
|
1219 |
-
file_path = os.path.join(DELETE_FOLDER, filename)
|
1220 |
-
try:
|
1221 |
-
if os.path.isfile(file_path) or os.path.islink(file_path):
|
1222 |
-
os.unlink(file_path)
|
1223 |
-
elif os.path.isdir(file_path):
|
1224 |
-
shutil.rmtree(file_path)
|
1225 |
-
except Exception as e:
|
1226 |
-
print('Failed to delete %s. Reason: %s' % (file_path, e))
|
1227 |
-
|
1228 |
-
|
1229 |
-
AGREE_POP_SCRIPTS = """
|
1230 |
-
async () => {
|
1231 |
-
alert("To use our service, you are required to agree to the following terms:\\nYou must not use our service to generate any harmful, unethical or illegal content that violates local and international laws, including but not limited to hate speech, violence and deception.\\nThe service may collect user dialogue data for performance improvement, and reserves the right to distribute it under CC-BY or similar license. So do not enter any personal information!");
|
1232 |
-
}
|
1233 |
-
"""
|
1234 |
-
|
1235 |
-
def debug_file_function(
|
1236 |
-
files: Union[str, List[str]],
|
1237 |
-
prompt_mode: str,
|
1238 |
-
temperature: float,
|
1239 |
-
max_tokens: int,
|
1240 |
-
frequency_penalty: float,
|
1241 |
-
presence_penalty: float,
|
1242 |
-
stop_strings: str = "[STOP],<s>,</s>",
|
1243 |
-
current_time: Optional[float] = None,
|
1244 |
-
):
|
1245 |
-
"""This is only for debug purpose"""
|
1246 |
-
files = files if isinstance(files, list) else [files]
|
1247 |
-
print(files)
|
1248 |
-
filenames = [f.name for f in files]
|
1249 |
-
all_items = []
|
1250 |
-
for fname in filenames:
|
1251 |
-
print(f'Reading {fname}')
|
1252 |
-
with open(fname, 'r', encoding='utf-8') as f:
|
1253 |
-
items = json.load(f)
|
1254 |
-
assert isinstance(items, list), f'invalid items from {fname} not list'
|
1255 |
-
all_items.extend(items)
|
1256 |
-
print(all_items)
|
1257 |
-
print(f'{prompt_mode} / {temperature} / {max_tokens}, {frequency_penalty}, {presence_penalty}')
|
1258 |
-
save_path = "./test.json"
|
1259 |
-
with open(save_path, 'w', encoding='utf-8') as f:
|
1260 |
-
json.dump(all_items, f, indent=4, ensure_ascii=False)
|
1261 |
-
|
1262 |
-
for x in all_items:
|
1263 |
-
x['response'] = "Return response"
|
1264 |
-
|
1265 |
-
print_items = all_items[:1]
|
1266 |
-
# print_json = json.dumps(print_items, indent=4, ensure_ascii=False)
|
1267 |
-
return save_path, print_items
|
1268 |
-
|
1269 |
-
|
1270 |
-
def validate_file_item(filename, index, item: Dict[str, str]):
|
1271 |
-
"""
|
1272 |
-
check safety for items in files
|
1273 |
-
"""
|
1274 |
-
message = item['prompt'].strip()
|
1275 |
-
|
1276 |
-
if len(message) == 0:
|
1277 |
-
raise gr.Error(f'Prompt {index} empty')
|
1278 |
-
|
1279 |
-
message_safety = safety_check(message, history=None)
|
1280 |
-
if message_safety is not None:
|
1281 |
-
raise gr.Error(f'Prompt {index} invalid: {message_safety}')
|
1282 |
-
|
1283 |
-
tokenizer = llm.get_tokenizer() if llm is not None else None
|
1284 |
-
if tokenizer is None or len(tokenizer.encode(message)) >= BATCH_INFER_MAX_PROMPT_TOKENS:
|
1285 |
-
raise gr.Error(f"Prompt {index} too long, should be less than {BATCH_INFER_MAX_PROMPT_TOKENS} tokens")
|
1286 |
-
|
1287 |
-
|
1288 |
-
def read_validate_json_files(files: Union[str, List[str]]):
|
1289 |
-
files = files if isinstance(files, list) else [files]
|
1290 |
-
filenames = [f.name for f in files]
|
1291 |
-
all_items = []
|
1292 |
-
for fname in filenames:
|
1293 |
-
# check each files
|
1294 |
-
print(f'Reading {fname}')
|
1295 |
-
with open(fname, 'r', encoding='utf-8') as f:
|
1296 |
-
items = json.load(f)
|
1297 |
-
assert isinstance(items, list), f'Data {fname} not list'
|
1298 |
-
assert all(isinstance(x, dict) for x in items), f'item in input file not list'
|
1299 |
-
assert all("prompt" in x for x in items), f'key prompt should be in dict item of input file'
|
1300 |
-
|
1301 |
-
for i, x in enumerate(items):
|
1302 |
-
validate_file_item(fname, i, x)
|
1303 |
|
1304 |
-
all_items.extend(items)
|
1305 |
|
1306 |
-
if len(all_items) > BATCH_INFER_MAX_ITEMS:
|
1307 |
-
raise gr.Error(f"Num samples {len(all_items)} > {BATCH_INFER_MAX_ITEMS} allowed.")
|
1308 |
-
|
1309 |
-
return all_items, filenames
|
1310 |
|
1311 |
-
|
1312 |
-
|
1313 |
-
"""remove gradio cache to avoid flooding"""
|
1314 |
import shutil
|
1315 |
-
|
1316 |
-
|
1317 |
-
|
1318 |
-
if exclude_names is None or not any(ef in f for ef in exclude_names):
|
1319 |
-
print(f'Remove: {f}')
|
1320 |
-
os.unlink(os.path.join(root, f))
|
1321 |
-
# for d in dirs:
|
1322 |
-
# # if not any(d in ef for ef in except_files):
|
1323 |
-
# if exclude_names is None or not any(ef in d for ef in exclude_names):
|
1324 |
-
# print(f'Remove d: {d}')
|
1325 |
-
# shutil.rmtree(os.path.join(root, d))
|
1326 |
-
|
1327 |
-
|
1328 |
-
def maybe_upload_batch_set(pred_json_path):
|
1329 |
-
global LOG_FILE, DATA_SET_REPO_PATH, SAVE_LOGS
|
1330 |
-
|
1331 |
-
if SAVE_LOGS and DATA_SET_REPO_PATH != "":
|
1332 |
try:
|
1333 |
-
|
1334 |
-
|
1335 |
-
|
1336 |
-
|
1337 |
-
|
1338 |
-
path_in_repo=path_in_repo,
|
1339 |
-
repo_id=DATA_SET_REPO_PATH,
|
1340 |
-
token=HF_TOKEN,
|
1341 |
-
repo_type="dataset",
|
1342 |
-
create_pr=True
|
1343 |
-
)
|
1344 |
except Exception as e:
|
1345 |
-
print(
|
1346 |
-
|
1347 |
-
|
1348 |
-
def free_form_prompt(prompt, history=None, system_prompt=None):
|
1349 |
-
return prompt
|
1350 |
-
|
1351 |
-
def batch_inference(
|
1352 |
-
files: Union[str, List[str]],
|
1353 |
-
prompt_mode: str,
|
1354 |
-
temperature: float,
|
1355 |
-
max_tokens: int,
|
1356 |
-
frequency_penalty: float,
|
1357 |
-
presence_penalty: float,
|
1358 |
-
stop_strings: str = "[STOP],<s>,</s>,<|im_start|>",
|
1359 |
-
current_time: Optional[float] = None,
|
1360 |
-
system_prompt: Optional[str] = SYSTEM_PROMPT_1
|
1361 |
-
):
|
1362 |
-
"""
|
1363 |
-
Handle file upload batch inference
|
1364 |
-
|
1365 |
-
"""
|
1366 |
-
global LOG_FILE, LOG_PATH, DEBUG, llm, RES_PRINTED
|
1367 |
-
if DEBUG:
|
1368 |
-
return debug_file_function(
|
1369 |
-
files, prompt_mode, temperature, max_tokens,
|
1370 |
-
presence_penalty, stop_strings, current_time)
|
1371 |
-
|
1372 |
-
from vllm import LLM, SamplingParams
|
1373 |
-
assert llm is not None
|
1374 |
-
# assert system_prompt.strip() != '', f'system prompt is empty'
|
1375 |
-
|
1376 |
-
stop_strings = [x.strip() for x in stop_strings.strip().split(",")]
|
1377 |
-
tokenizer = llm.get_tokenizer()
|
1378 |
-
# force removing all
|
1379 |
-
# NOTE: need to make sure all cached items are removed!!!!!!!!!
|
1380 |
-
vllm_abort(llm)
|
1381 |
-
|
1382 |
-
temperature = float(temperature)
|
1383 |
-
frequency_penalty = float(frequency_penalty)
|
1384 |
-
max_tokens = int(max_tokens)
|
1385 |
-
|
1386 |
-
all_items, filenames = read_validate_json_files(files)
|
1387 |
-
|
1388 |
-
# remove all items in /tmp/gradio/
|
1389 |
-
remove_gradio_cache(exclude_names=['upload_chat.json', 'upload_few_shot.json'])
|
1390 |
-
|
1391 |
-
if prompt_mode == 'chat':
|
1392 |
-
prompt_format_fn = chatml_format
|
1393 |
-
elif prompt_mode == 'few-shot':
|
1394 |
-
from functools import partial
|
1395 |
-
# prompt_format_fn = partial(
|
1396 |
-
# chatml_format, include_end_instruct=False
|
1397 |
-
# )
|
1398 |
-
prompt_format_fn = free_form_prompt
|
1399 |
-
else:
|
1400 |
-
raise gr.Error(f'Wrong mode {prompt_mode}')
|
1401 |
-
|
1402 |
-
full_prompts = [
|
1403 |
-
prompt_format_fn(
|
1404 |
-
x['prompt'], [], sys_prompt=system_prompt
|
1405 |
-
)
|
1406 |
-
for i, x in enumerate(all_items)
|
1407 |
-
]
|
1408 |
-
print(f'{full_prompts[0]}\n')
|
1409 |
-
|
1410 |
-
if any(len(tokenizer.encode(x)) >= 4090 for x in full_prompts):
|
1411 |
-
raise gr.Error(f"Some prompt is too long!")
|
1412 |
-
|
1413 |
-
stop_seq = list(set(['<s>', '</s>', '<<SYS>>', '<</SYS>>', '[INST]', '[/INST]'] + stop_strings))
|
1414 |
-
sampling_params = SamplingParams(
|
1415 |
-
temperature=temperature,
|
1416 |
-
max_tokens=max_tokens,
|
1417 |
-
frequency_penalty=frequency_penalty,
|
1418 |
-
presence_penalty=presence_penalty,
|
1419 |
-
stop=stop_seq
|
1420 |
-
)
|
1421 |
-
|
1422 |
-
generated = llm.generate(full_prompts, sampling_params, use_tqdm=False)
|
1423 |
-
responses = [g.outputs[0].text for g in generated]
|
1424 |
-
#responses = ["Our system is under maintenance, will be back soon!" for g in generated]
|
1425 |
-
if len(responses) != len(all_items):
|
1426 |
-
raise gr.Error(f'inconsistent lengths {len(responses)} != {len(all_items)}')
|
1427 |
-
|
1428 |
-
for res, item in zip(responses, all_items):
|
1429 |
-
item['response'] = res
|
1430 |
-
|
1431 |
-
save_path = BATCH_INFER_SAVE_TMP_FILE
|
1432 |
-
os.makedirs(os.path.dirname(save_path), exist_ok=True)
|
1433 |
-
with open(save_path, 'w', encoding='utf-8') as f:
|
1434 |
-
json.dump(all_items, f, indent=4, ensure_ascii=False)
|
1435 |
-
|
1436 |
-
# You need to upload save_path as a new timestamp file.
|
1437 |
-
maybe_upload_batch_set(save_path)
|
1438 |
-
|
1439 |
-
print_items = all_items[:2]
|
1440 |
-
# print_json = json.dumps(print_items, indent=4, ensure_ascii=False)
|
1441 |
-
return save_path, print_items
|
1442 |
-
|
1443 |
-
|
1444 |
-
# BATCH_INFER_MAX_ITEMS
|
1445 |
-
FILE_UPLOAD_DESCRIPTION = f"""Upload JSON file as list of dict with < {BATCH_INFER_MAX_ITEMS} items, \
|
1446 |
-
each item has `prompt` key. We put guardrails to enhance safety, so do not input any harmful content or personal information! Re-upload the file after every submit. See the examples below.
|
1447 |
-
```
|
1448 |
-
[ {{"id": 0, "prompt": "Hello world"}} , {{"id": 1, "prompt": "Hi there?"}}]
|
1449 |
-
```
|
1450 |
-
"""
|
1451 |
-
|
1452 |
-
CHAT_EXAMPLES = [
|
1453 |
-
["Hãy giải thích thuyết tương đối rộng."],
|
1454 |
-
["Tolong bantu saya menulis email ke lembaga pemerintah untuk mencari dukungan finansial untuk penelitian AI."],
|
1455 |
-
["แนะนำ 10 จุดหมายปลายทางในกรุงเทพฯ"],
|
1456 |
-
]
|
1457 |
-
|
1458 |
-
|
1459 |
-
# performance items
|
1460 |
-
|
1461 |
-
def create_free_form_generation_demo():
|
1462 |
-
global short_model_path
|
1463 |
-
max_tokens = MAX_TOKENS
|
1464 |
-
temperature = TEMPERATURE
|
1465 |
-
frequence_penalty = FREQUENCE_PENALTY
|
1466 |
-
presence_penalty = PRESENCE_PENALTY
|
1467 |
-
|
1468 |
-
introduction = """
|
1469 |
-
### Free-form | Put any context string (like few-shot prompts)
|
1470 |
-
"""
|
1471 |
-
|
1472 |
-
with gr.Blocks() as demo_free_form:
|
1473 |
-
gr.Markdown(introduction)
|
1474 |
-
|
1475 |
-
with gr.Row():
|
1476 |
-
txt = gr.Textbox(
|
1477 |
-
scale=4,
|
1478 |
-
lines=16,
|
1479 |
-
show_label=False,
|
1480 |
-
placeholder="Enter any free form text and submit",
|
1481 |
-
container=False,
|
1482 |
-
)
|
1483 |
-
with gr.Row():
|
1484 |
-
free_submit_button = gr.Button('Submit')
|
1485 |
-
with gr.Row():
|
1486 |
-
temp = gr.Number(value=temperature, label='Temperature', info="Higher -> more random")
|
1487 |
-
length = gr.Number(value=max_tokens, label='Max tokens', info='Increase if want more generation')
|
1488 |
-
freq_pen = gr.Number(value=frequence_penalty, label='Frequency penalty', info='> 0 encourage new tokens over repeated tokens')
|
1489 |
-
pres_pen = gr.Number(value=presence_penalty, label='Presence penalty', info='> 0 encourage new tokens, < 0 encourage existing tokens')
|
1490 |
-
stop_strings = gr.Textbox(value="<s>,</s>,<|im_start|>", label='Stop strings', info='Comma-separated string to stop generation only in FEW-SHOT mode', lines=1)
|
1491 |
-
|
1492 |
-
free_submit_button.click(
|
1493 |
-
generate_free_form_stream,
|
1494 |
-
[txt, temp, length, freq_pen, pres_pen, stop_strings],
|
1495 |
-
txt
|
1496 |
-
)
|
1497 |
-
return demo_free_form
|
1498 |
-
|
1499 |
-
|
1500 |
-
|
1501 |
-
def create_file_upload_demo():
|
1502 |
-
temperature = TEMPERATURE
|
1503 |
-
frequence_penalty = FREQUENCE_PENALTY
|
1504 |
-
presence_penalty = PRESENCE_PENALTY
|
1505 |
-
max_tokens = MAX_TOKENS
|
1506 |
-
demo_file_upload = gr.Interface(
|
1507 |
-
batch_inference,
|
1508 |
-
inputs=[
|
1509 |
-
gr.File(file_count='single', file_types=['json']),
|
1510 |
-
gr.Radio(["chat", "few-shot"], value='chat', label="Chat or Few-shot mode", info="Chat's output more user-friendly, Few-shot's output more consistent with few-shot patterns."),
|
1511 |
-
gr.Number(value=temperature, label='Temperature', info="Higher -> more random"),
|
1512 |
-
gr.Number(value=max_tokens, label='Max tokens', info='Increase if want more generation'),
|
1513 |
-
gr.Number(value=frequence_penalty, label='Frequency penalty', info='> 0 encourage new tokens over repeated tokens'),
|
1514 |
-
gr.Number(value=presence_penalty, label='Presence penalty', info='> 0 encourage new tokens, < 0 encourage existing tokens'),
|
1515 |
-
gr.Textbox(value="<s>,</s>,<|im_start|>", label='Stop strings', info='Comma-separated string to stop generation only in FEW-SHOT mode', lines=1),
|
1516 |
-
gr.Number(value=0, label='current_time', visible=False),
|
1517 |
-
],
|
1518 |
-
outputs=[
|
1519 |
-
# "file",
|
1520 |
-
gr.File(label="Generated file"),
|
1521 |
-
# "json"
|
1522 |
-
gr.JSON(label='Example outputs (display 2 samples)')
|
1523 |
-
],
|
1524 |
-
description=FILE_UPLOAD_DESCRIPTION,
|
1525 |
-
allow_flagging=False,
|
1526 |
-
examples=[
|
1527 |
-
["upload_chat.json", "chat", 0.2, 1024, 0.5, 0, "<s>,</s>,<|im_start|>"],
|
1528 |
-
["upload_few_shot.json", "few-shot", 0.2, 128, 0.5, 0, "<s>,</s>,<|im_start|>,\\n"]
|
1529 |
-
],
|
1530 |
-
cache_examples=False,
|
1531 |
-
)
|
1532 |
-
return demo_file_upload
|
1533 |
-
|
1534 |
-
|
1535 |
-
def create_chat_demo(title=None, description=None):
|
1536 |
-
sys_prompt = SYSTEM_PROMPT_1
|
1537 |
-
max_tokens = MAX_TOKENS
|
1538 |
-
temperature = TEMPERATURE
|
1539 |
-
frequence_penalty = FREQUENCE_PENALTY
|
1540 |
-
presence_penalty = PRESENCE_PENALTY
|
1541 |
-
|
1542 |
-
demo_chat = gr.ChatInterface(
|
1543 |
-
chat_response_stream_multiturn,
|
1544 |
-
chatbot=ChatBot(
|
1545 |
-
label=MODEL_NAME,
|
1546 |
-
bubble_full_width=False,
|
1547 |
-
latex_delimiters=[
|
1548 |
-
{ "left": "$", "right": "$", "display": False},
|
1549 |
-
{ "left": "$$", "right": "$$", "display": True},
|
1550 |
-
],
|
1551 |
-
show_copy_button=True,
|
1552 |
-
),
|
1553 |
-
textbox=gr.Textbox(placeholder='Type message', lines=1, max_lines=128, min_width=200),
|
1554 |
-
submit_btn=gr.Button(value='Submit', variant="primary", scale=0),
|
1555 |
-
# ! consider preventing the stop button
|
1556 |
-
# stop_btn=None,
|
1557 |
-
title=title,
|
1558 |
-
description=description,
|
1559 |
-
additional_inputs=[
|
1560 |
-
gr.Number(value=temperature, label='Temperature (higher -> more random)'),
|
1561 |
-
gr.Number(value=max_tokens, label='Max generated tokens (increase if want more generation)'),
|
1562 |
-
gr.Number(value=frequence_penalty, label='Frequency penalty (> 0 encourage new tokens over repeated tokens)'),
|
1563 |
-
gr.Number(value=presence_penalty, label='Presence penalty (> 0 encourage new tokens, < 0 encourage existing tokens)'),
|
1564 |
-
gr.Textbox(value=sys_prompt, label='System prompt', lines=4, interactive=False),
|
1565 |
-
gr.Number(value=0, label='current_time', visible=False),
|
1566 |
-
# ! Remove the system prompt textbox to avoid jailbreaking
|
1567 |
-
],
|
1568 |
-
examples=CHAT_EXAMPLES,
|
1569 |
-
cache_examples=False
|
1570 |
-
)
|
1571 |
-
return demo_chat
|
1572 |
-
|
1573 |
-
|
1574 |
-
def upload_file(file):
|
1575 |
-
# file_paths = [file.name for file in files]
|
1576 |
-
# return file_paths
|
1577 |
-
return file.name
|
1578 |
-
|
1579 |
-
|
1580 |
-
RAG_DESCRIPTION = """
|
1581 |
-
* Upload a doc below to answer question about it (RAG).
|
1582 |
-
* Every question must be explicit and self-contained! Because each prompt will invoke a new RAG retrieval without considering previous conversations.
|
1583 |
-
(E.g: Dont prompt "Answer my previous question in details.")
|
1584 |
-
"""
|
1585 |
-
|
1586 |
-
def create_chat_demo_rag(title=None, description=None):
|
1587 |
-
sys_prompt = SYSTEM_PROMPT_1
|
1588 |
-
max_tokens = MAX_TOKENS
|
1589 |
-
temperature = TEMPERATURE
|
1590 |
-
frequence_penalty = FREQUENCE_PENALTY
|
1591 |
-
presence_penalty = PRESENCE_PENALTY
|
1592 |
-
description = description or RAG_DESCRIPTION
|
1593 |
-
|
1594 |
-
# with gr.Blocks(title="RAG") as rag_demo:
|
1595 |
-
additional_inputs = [
|
1596 |
-
gr.File(label='Upload Document', file_count='single', file_types=['pdf', 'docx', 'txt', 'json']),
|
1597 |
-
# gr.Textbox(value=None, label='Document path', lines=1, interactive=False),
|
1598 |
-
gr.Number(value=temperature, label='Temperature (higher -> more random)'),
|
1599 |
-
gr.Number(value=max_tokens, label='Max generated tokens (increase if want more generation)'),
|
1600 |
-
# gr.Number(value=frequence_penalty, label='Frequency penalty (> 0 encourage new tokens over repeated tokens)'),
|
1601 |
-
# gr.Number(value=presence_penalty, label='Presence penalty (> 0 encourage new tokens, < 0 encourage existing tokens)'),
|
1602 |
-
gr.Textbox(value=sys_prompt, label='System prompt', lines=1, interactive=False),
|
1603 |
-
gr.Number(value=0, label='current_time', visible=False),
|
1604 |
-
]
|
1605 |
-
|
1606 |
-
demo_rag_chat = gr.ChatInterface(
|
1607 |
-
chat_response_stream_rag_multiturn,
|
1608 |
-
chatbot=gr.Chatbot(
|
1609 |
-
label=MODEL_NAME + "-RAG",
|
1610 |
-
bubble_full_width=False,
|
1611 |
-
latex_delimiters=[
|
1612 |
-
{ "left": "$", "right": "$", "display": False},
|
1613 |
-
{ "left": "$$", "right": "$$", "display": True},
|
1614 |
-
],
|
1615 |
-
show_copy_button=True,
|
1616 |
-
),
|
1617 |
-
textbox=gr.Textbox(placeholder='Type message', lines=1, max_lines=128, min_width=200),
|
1618 |
-
submit_btn=gr.Button(value='Submit', variant="primary", scale=0),
|
1619 |
-
# ! consider preventing the stop button
|
1620 |
-
# stop_btn=None,
|
1621 |
-
title=title,
|
1622 |
-
description=description,
|
1623 |
-
additional_inputs=additional_inputs,
|
1624 |
-
additional_inputs_accordion=gr.Accordion("Additional Inputs", open=True),
|
1625 |
-
# examples=CHAT_EXAMPLES,
|
1626 |
-
cache_examples=False
|
1627 |
-
)
|
1628 |
-
# with demo_rag_chat:
|
1629 |
-
# upload_button = gr.UploadButton("Click to Upload document", file_types=['pdf', 'docx', 'txt', 'json'], file_count="single")
|
1630 |
-
# upload_button.upload(upload_file, upload_button, additional_inputs[0])
|
1631 |
-
|
1632 |
-
# return demo_chat
|
1633 |
-
return demo_rag_chat
|
1634 |
-
|
1635 |
|
1636 |
|
1637 |
def launch_demo():
|
1638 |
-
global demo,
|
1639 |
model_desc = MODEL_DESC
|
1640 |
model_path = MODEL_PATH
|
1641 |
-
model_title = MODEL_TITLE
|
1642 |
-
hf_model_name = HF_MODEL_NAME
|
1643 |
-
tensor_parallel = TENSOR_PARALLEL
|
1644 |
-
assert tensor_parallel > 0 , f'{tensor_parallel} invalid'
|
1645 |
-
dtype = DTYPE
|
1646 |
-
sys_prompt = SYSTEM_PROMPT_1
|
1647 |
-
max_tokens = MAX_TOKENS
|
1648 |
-
temperature = TEMPERATURE
|
1649 |
-
frequence_penalty = FREQUENCE_PENALTY
|
1650 |
-
presence_penalty = PRESENCE_PENALTY
|
1651 |
-
ckpt_info = "None"
|
1652 |
-
|
1653 |
-
print(
|
1654 |
-
f'Launch config: '
|
1655 |
-
f'\n| model_title=`{model_title}` '
|
1656 |
-
f'\n| max_tokens={max_tokens} '
|
1657 |
-
f'\n| dtype={dtype} '
|
1658 |
-
f'\n| tensor_parallel={tensor_parallel} '
|
1659 |
-
f'\n| IS_DELETE_FOLDER={IS_DELETE_FOLDER} '
|
1660 |
-
f'\n| STREAM_YIELD_MULTIPLE={STREAM_YIELD_MULTIPLE} '
|
1661 |
-
f'\n| STREAM_CHECK_MULTIPLE={STREAM_CHECK_MULTIPLE} '
|
1662 |
-
f'\n| DISPLAY_MODEL_PATH={DISPLAY_MODEL_PATH} '
|
1663 |
-
f'\n| LANG_BLOCK_HISTORY={LANG_BLOCK_HISTORY} '
|
1664 |
-
f'\n| frequence_penalty={frequence_penalty} '
|
1665 |
-
f'\n| presence_penalty={presence_penalty} '
|
1666 |
-
f'\n| temperature={temperature} '
|
1667 |
-
# f'\n| hf_model_name={hf_model_name} '
|
1668 |
-
f'\n| model_path={model_path} '
|
1669 |
-
f'\n| DOWNLOAD_SNAPSHOT={DOWNLOAD_SNAPSHOT} '
|
1670 |
-
f'\n| gpu_memory_utilization={gpu_memory_utilization} '
|
1671 |
-
f'\n| LOG_PATH={LOG_PATH} | SAVE_LOGS={SAVE_LOGS} '
|
1672 |
-
f'\n| Desc={model_desc}'
|
1673 |
-
)
|
1674 |
-
|
1675 |
-
if DEBUG:
|
1676 |
-
model_desc += "\n<br>!!!!! This is in debug mode, responses will copy original"
|
1677 |
-
# response_fn = debug_chat_response_echo
|
1678 |
-
response_fn = chat_response_stream_multiturn
|
1679 |
-
print(f'Creating in DEBUG MODE')
|
1680 |
-
if SAVE_LOGS:
|
1681 |
-
LOG_FILE = open(LOG_PATH, 'a', encoding='utf-8')
|
1682 |
-
else:
|
1683 |
-
# ! load the model
|
1684 |
-
maybe_delete_folder()
|
1685 |
-
|
1686 |
-
if DOWNLOAD_SNAPSHOT:
|
1687 |
-
print(f'Downloading from HF_MODEL_NAME={hf_model_name} -> {model_path}')
|
1688 |
-
if HF_TOKEN is not None:
|
1689 |
-
print(f'Load with HF_TOKEN: {HF_TOKEN}')
|
1690 |
-
snapshot_download(hf_model_name, local_dir=model_path, use_auth_token=True, token=HF_TOKEN)
|
1691 |
-
else:
|
1692 |
-
snapshot_download(hf_model_name, local_dir=model_path)
|
1693 |
-
|
1694 |
-
import vllm
|
1695 |
-
from vllm import LLM
|
1696 |
-
|
1697 |
-
print(F'VLLM: {vllm.__version__}')
|
1698 |
-
ckpt_info = check_model_path(model_path)
|
1699 |
-
|
1700 |
-
print(f'Load path: {model_path} | {ckpt_info}')
|
1701 |
|
1702 |
-
|
1703 |
-
|
1704 |
-
|
1705 |
-
|
1706 |
-
|
1707 |
-
|
1708 |
-
|
1709 |
-
|
1710 |
-
|
1711 |
-
|
1712 |
-
|
1713 |
-
|
1714 |
-
|
1715 |
-
|
1716 |
-
|
1717 |
-
|
1718 |
-
|
1719 |
-
|
1720 |
-
|
1721 |
-
|
1722 |
-
|
1723 |
-
|
1724 |
-
|
1725 |
-
if SAVE_LOGS:
|
1726 |
-
LOG_FILE = open(LOG_PATH, 'a', encoding='utf-8')
|
1727 |
-
|
1728 |
-
if ENABLE_BATCH_INFER:
|
1729 |
-
|
1730 |
-
# demo_file_upload = create_file_upload_demo()
|
1731 |
-
|
1732 |
-
demo_free_form = create_free_form_generation_demo()
|
1733 |
-
|
1734 |
-
demo_chat = create_chat_demo()
|
1735 |
-
demo_chat_rag = create_chat_demo_rag(description=RAG_DESCRIPTION)
|
1736 |
-
descriptions = model_desc
|
1737 |
-
if DISPLAY_MODEL_PATH:
|
1738 |
-
descriptions += f"<br> {path_markdown.format(model_path=model_path)}"
|
1739 |
-
|
1740 |
-
demo = CustomTabbedInterface(
|
1741 |
-
interface_list=[
|
1742 |
-
demo_chat,
|
1743 |
-
demo_chat_rag,
|
1744 |
-
demo_free_form,
|
1745 |
-
# demo_file_upload,
|
1746 |
-
],
|
1747 |
-
tab_names=[
|
1748 |
-
"Chat Interface",
|
1749 |
-
"RAG Chat Interface",
|
1750 |
-
"Text completion",
|
1751 |
-
# "Batch Inference",
|
1752 |
-
],
|
1753 |
-
title=f"{model_title}",
|
1754 |
-
description=descriptions,
|
1755 |
)
|
1756 |
-
else:
|
1757 |
-
descriptions = model_desc
|
1758 |
-
if DISPLAY_MODEL_PATH:
|
1759 |
-
descriptions += f"<br> {path_markdown.format(model_path=model_path)}"
|
1760 |
|
1761 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
1762 |
demo.title = MODEL_NAME
|
1763 |
|
1764 |
with demo:
|
1765 |
-
|
1766 |
-
try:
|
1767 |
-
from performance_plot import attach_plot_to_demo
|
1768 |
-
attach_plot_to_demo(demo)
|
1769 |
-
except Exception as e:
|
1770 |
-
print(f'Fail to load DEMO plot: {str(e)}')
|
1771 |
-
|
1772 |
-
gr.Markdown(cite_markdown)
|
1773 |
-
if DISPLAY_MODEL_PATH:
|
1774 |
-
gr.Markdown(path_markdown.format(model_path=model_path))
|
1775 |
|
1776 |
-
if ENABLE_AGREE_POPUP:
|
1777 |
-
demo.load(None, None, None, _js=AGREE_POP_SCRIPTS)
|
1778 |
-
|
1779 |
-
# login_btn = gr.LoginButton()
|
1780 |
-
|
1781 |
demo.queue(api_open=False)
|
1782 |
return demo
|
1783 |
|
1784 |
|
|
|
1785 |
if __name__ == "__main__":
|
1786 |
demo = launch_demo()
|
1787 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
3 |
|
4 |
# Description:
|
5 |
"""
|
6 |
+
Demo script to launch Language chat model
|
7 |
"""
|
8 |
|
9 |
|
10 |
import os
|
11 |
+
from gradio.themes import ThemeClass as Theme
|
12 |
import numpy as np
|
13 |
import argparse
|
14 |
+
# import torch
|
15 |
import gradio as gr
|
16 |
from typing import Any, Iterator
|
17 |
from typing import Iterator, List, Optional, Tuple
|
|
|
30 |
from typing import List, Optional, Union, Dict, Tuple
|
31 |
from tqdm.auto import tqdm
|
32 |
from huggingface_hub import snapshot_download
|
33 |
+
from langchain_community.embeddings import HuggingFaceEmbeddings, HuggingFaceBgeEmbeddings
|
34 |
+
from gradio.components import Button, Component
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
35 |
from gradio.events import Dependency, EventListenerMethod
|
36 |
|
37 |
+
from multipurpose_chatbot.demos.base_demo import CustomTabbedInterface
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
38 |
|
39 |
+
from multipurpose_chatbot.configs import (
|
40 |
+
MODEL_TITLE,
|
41 |
+
MODEL_DESC,
|
42 |
+
MODEL_INFO,
|
43 |
+
CITE_MARKDOWN,
|
44 |
+
ALLOWED_PATHS,
|
45 |
+
PROXY,
|
46 |
+
PORT,
|
47 |
+
MODEL_PATH,
|
48 |
+
MODEL_NAME,
|
49 |
+
BACKEND,
|
50 |
+
DEMOS,
|
51 |
+
DELETE_FOLDER,
|
52 |
+
)
|
53 |
|
|
|
|
|
54 |
|
55 |
+
demo = None
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
56 |
|
|
|
57 |
|
|
|
|
|
|
|
|
|
58 |
|
59 |
+
if DELETE_FOLDER is not None and os.path.exists(DELETE_FOLDER):
|
60 |
+
print(F'WARNING deleting folder: {DELETE_FOLDER}')
|
|
|
61 |
import shutil
|
62 |
+
print(f'DELETE ALL FILES IN {DELETE_FOLDER}')
|
63 |
+
for filename in os.listdir(DELETE_FOLDER):
|
64 |
+
file_path = os.path.join(DELETE_FOLDER, filename)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
65 |
try:
|
66 |
+
if os.path.isfile(file_path) or os.path.islink(file_path):
|
67 |
+
os.unlink(file_path)
|
68 |
+
elif os.path.isdir(file_path):
|
69 |
+
shutil.rmtree(file_path)
|
70 |
+
print(f'deleted: {file_path}')
|
|
|
|
|
|
|
|
|
|
|
|
|
71 |
except Exception as e:
|
72 |
+
print('Failed to delete %s. Reason: %s' % (file_path, e))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
73 |
|
74 |
|
75 |
def launch_demo():
|
76 |
+
global demo, MODEL_ENGINE
|
77 |
model_desc = MODEL_DESC
|
78 |
model_path = MODEL_PATH
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
79 |
|
80 |
+
print(f'Begin importing models')
|
81 |
+
from multipurpose_chatbot.demos import get_demo_class
|
82 |
+
|
83 |
+
# demos = {
|
84 |
+
# k: get_demo_class(k)().create_demo()
|
85 |
+
# for k in demo_and_tab_names.keys()
|
86 |
+
# }
|
87 |
+
print(f'{DEMOS=}')
|
88 |
+
demo_class_objects = {
|
89 |
+
k: get_demo_class(k)()
|
90 |
+
for k in DEMOS
|
91 |
+
}
|
92 |
+
demos = {
|
93 |
+
k: get_demo_class(k)().create_demo()
|
94 |
+
for k in DEMOS
|
95 |
+
}
|
96 |
+
demos_names = [x.tab_name for x in demo_class_objects.values()]
|
97 |
+
|
98 |
+
descriptions = model_desc
|
99 |
+
if MODEL_INFO is not None and MODEL_INFO != "":
|
100 |
+
descriptions += (
|
101 |
+
f"<br>" +
|
102 |
+
MODEL_INFO.format(model_path=model_path)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
103 |
)
|
|
|
|
|
|
|
|
|
104 |
|
105 |
+
demo = CustomTabbedInterface(
|
106 |
+
interface_list=list(demos.values()),
|
107 |
+
tab_names=demos_names,
|
108 |
+
title=f"{MODEL_TITLE}",
|
109 |
+
description=descriptions,
|
110 |
+
)
|
111 |
+
|
112 |
demo.title = MODEL_NAME
|
113 |
|
114 |
with demo:
|
115 |
+
gr.Markdown(CITE_MARKDOWN)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
116 |
|
|
|
|
|
|
|
|
|
|
|
117 |
demo.queue(api_open=False)
|
118 |
return demo
|
119 |
|
120 |
|
121 |
+
|
122 |
if __name__ == "__main__":
|
123 |
demo = launch_demo()
|
124 |
+
if PROXY is not None and PROXY != "":
|
125 |
+
print(f'{PROXY=} {PORT=}')
|
126 |
+
print(f"{ALLOWED_PATHS=}")
|
127 |
+
demo.launch(server_port=PORT, root_path=PROXY, show_api=False, allowed_paths=ALLOWED_PATHS)
|
128 |
+
else:
|
129 |
+
demo.launch(server_port=PORT, show_api=False, allowed_paths=ALLOWED_PATHS)
|
130 |
+
|
131 |
+
|
assets/.DS_Store
ADDED
Binary file (6.15 kB). View file
|
|
assets/attention_all_you_need.pdf
ADDED
Binary file (858 kB). View file
|
|
assets/attention_short.pdf
ADDED
Binary file (236 kB). View file
|
|
assets/dog_monalisa.jpeg
ADDED
assets/upload_chat.json
ADDED
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
[
|
2 |
+
{
|
3 |
+
"id": "1",
|
4 |
+
"prompt": "Tell me something about AI?"
|
5 |
+
},
|
6 |
+
{
|
7 |
+
"id": "2",
|
8 |
+
"prompt": "Who are you?"
|
9 |
+
}
|
10 |
+
]
|
assets/upload_few_shot.json
ADDED
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
[
|
2 |
+
{
|
3 |
+
"id": "0",
|
4 |
+
"prompt": "Translate Indonesian to English.\nIndonesian: \"Mereka melakukan hal ini dengan cara memancarkan sebuah partikel kecil cahaya kecil yang biasa disebut \"foton\".\"\nEnglish: They do this by emitting a tiny particle of light called a \"photon\".\n\nTranslate Indonesian to English.\nIndonesian: Kami melewati waktu seperti rangkaian peristiwa yang berlalu dari masa depan hingga masa kini lalu ke masa lalu.\nEnglish: We experience time as a series of events passing from the future through the present to the past.\n\nTranslate Indonesian to English.\nIndonesian: Canyoning (atau: canyoneering) adalah segala aktivitas yang terjadi di dasar ngarai, yang kering atau penuh air.\nEnglish: Canyoning (or: canyoneering) is about going in a bottom of a canyon, which is either dry or full of water.\n\nTranslate Indonesian to English.\nIndonesian: Mohon diingat bahwa intinya Anda sedang berkunjung ke situs kuburan massal, serta situs yang maknanya tak terhitung bagi sejumlah populasi dunia yang signifikan.\nEnglish:"
|
5 |
+
},
|
6 |
+
{
|
7 |
+
"id": "1",
|
8 |
+
"prompt": "Translate Indonesian to English.\nIndonesian: \"Mereka melakukan hal ini dengan cara memancarkan sebuah partikel kecil cahaya kecil yang biasa disebut \"foton\".\"\nEnglish: They do this by emitting a tiny particle of light called a \"photon\".\n\nTranslate Indonesian to English.\nIndonesian: Kami melewati waktu seperti rangkaian peristiwa yang berlalu dari masa depan hingga masa kini lalu ke masa lalu.\nEnglish: We experience time as a series of events passing from the future through the present to the past.\n\nTranslate Indonesian to English.\nIndonesian: Canyoning (atau: canyoneering) adalah segala aktivitas yang terjadi di dasar ngarai, yang kering atau penuh air.\nEnglish: Canyoning (or: canyoneering) is about going in a bottom of a canyon, which is either dry or full of water.\n\nTranslate Indonesian to English.\nIndonesian: Serangga adalah hewan pertama yang menjelajah angkasa. Kemampuan terbangnya membantu mereka menghindari musuh dengan lebih mudah dan mencari makanan dan pasangan dengan lebih efisien.\nEnglish:"
|
9 |
+
}
|
10 |
+
]
|
llama_cpp_requirements.txt
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
llama-cpp-python
|
mlx_requirements.txt
ADDED
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
1 |
+
mlx
|
2 |
+
mlx-lm
|
multipurpose_chatbot/.DS_Store
ADDED
Binary file (6.15 kB). View file
|
|
multipurpose_chatbot/__init__.py
ADDED
File without changes
|
multipurpose_chatbot/configs.py
ADDED
@@ -0,0 +1,140 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
|
2 |
+
import os
|
3 |
+
|
4 |
+
# ! UI Markdown information
|
5 |
+
|
6 |
+
MODEL_TITLE = """
|
7 |
+
<img src="file/seammm_2.png" style="
|
8 |
+
max-width: 10em;
|
9 |
+
max-height: 5%;
|
10 |
+
height: 3em;
|
11 |
+
width: 3em;
|
12 |
+
">
|
13 |
+
<div class="text" style="
|
14 |
+
loat: left;
|
15 |
+
padding-bottom: 2%;
|
16 |
+
">
|
17 |
+
SeaLMMM - Large Multilingual Multimodal Models for Southeast Asia
|
18 |
+
</div>
|
19 |
+
"""
|
20 |
+
|
21 |
+
# <a href='https://huggingface.co/spaces/SeaLLMs/SeaLMMM-7b'><img src='https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Spaces-blue'></a>
|
22 |
+
# <a href='https://huggingface.co/SeaLLMs/SeaLLM-7B-v2'><img src='https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Model-blue'></a>
|
23 |
+
#
|
24 |
+
MODEL_DESC = f"""
|
25 |
+
<div style='display:flex; gap: 0.25rem; '>
|
26 |
+
<a href='https://github.com/damo-nlp-sg/seallms'><img src='https://img.shields.io/badge/Github-Code-success'></a>
|
27 |
+
<a href='https://huggingface.co/spaces/SeaLLMs/SeaLLM-7B'><img src='https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Spaces-blue'></a>
|
28 |
+
<a href='https://huggingface.co/SeaLLMs/SeaLMMM-7B-early'><img src='https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Model-blue'></a>
|
29 |
+
</div>
|
30 |
+
<span style="font-size: larger">
|
31 |
+
<a href="https://huggingface.co/SeaLLMs/SeaLMMM-7B-early" target="_blank">SeaLMMM-7B-early</a> - multilingual multimodal assistant for Southeast Asia. It handles <b>both</b> text-only (<a href="https://huggingface.co/SeaLLMs/SeaLLM-7B-v2" target="_blank">LLMs</a> and vision instructions (LVMs). <span style="color: red">SeaLMMM-7B has not finished training.</span>
|
32 |
+
</span>
|
33 |
+
<br>
|
34 |
+
<span>
|
35 |
+
<span style="color: red">The chatbot may produce false and harmful content!</span>
|
36 |
+
By using our service, you are required to agree to our <a href="https://huggingface.co/SeaLLMs/SeaLLM-Chat-13b/blob/main/LICENSE" target="_blank" style="color: red">Terms Of Use</a>
|
37 |
+
</span>
|
38 |
+
|
39 |
+
""".strip()
|
40 |
+
|
41 |
+
"""
|
42 |
+
By using our service, you are required to agree to our <a href="https://huggingface.co/SeaLLMs/SeaLLM-Chat-13b/blob/main/LICENSE" target="_blank" style="color: red">Terms Of Use</a>, which includes
|
43 |
+
not to use our service to generate any harmful, inappropriate or illegal content.
|
44 |
+
The service collects user dialogue data for testing and improvement under
|
45 |
+
<a href="https://creativecommons.org/licenses/by/4.0/">(CC-BY)</a> or similar license. So do not enter any personal information!
|
46 |
+
|
47 |
+
"""
|
48 |
+
|
49 |
+
|
50 |
+
# MODEL_INFO = """
|
51 |
+
# <h4 style="display: hidden;">Model Name: {model_path}</h4>
|
52 |
+
# """
|
53 |
+
MODEL_INFO = ""
|
54 |
+
|
55 |
+
CITE_MARKDOWN = """
|
56 |
+
## Citation
|
57 |
+
If you find our project useful, hope you can star our repo and cite our paper as follows:
|
58 |
+
```
|
59 |
+
@article{damonlpsg2023seallm,
|
60 |
+
author = {Xuan-Phi Nguyen*, Wenxuan Zhang*, Xin Li*, Mahani Aljunied*, Zhiqiang Hu, Chenhui Shen^, Yew Ken Chia^, Xingxuan Li, Jianyu Wang, Qingyu Tan, Liying Cheng, Guanzheng Chen, Yue Deng, Sen Yang, Chaoqun Liu, Hang Zhang, Lidong Bing},
|
61 |
+
title = {SeaLLMs - Large Language Models for Southeast Asia},
|
62 |
+
year = 2023,
|
63 |
+
}
|
64 |
+
```
|
65 |
+
|
66 |
+
"""
|
67 |
+
USE_PANEL = bool(int(os.environ.get("USE_PANEL", "1")))
|
68 |
+
CHATBOT_HEIGHT = int(os.environ.get("CHATBOT_HEIGHT", "500"))
|
69 |
+
|
70 |
+
ALLOWED_PATHS = ["seammm_2.png"]
|
71 |
+
|
72 |
+
|
73 |
+
DEMOS = os.environ.get("DEMOS", "")
|
74 |
+
|
75 |
+
DEMOS = DEMOS.split(",") if DEMOS.strip() != "" else [
|
76 |
+
"DocChatInterfaceDemo",
|
77 |
+
"ChatInterfaceDemo",
|
78 |
+
"TextCompletionDemo",
|
79 |
+
# "RagChatInterfaceDemo",
|
80 |
+
# "VisionChatInterfaceDemo",
|
81 |
+
# "VisionDocChatInterfaceDemo",
|
82 |
+
]
|
83 |
+
|
84 |
+
# DEMOS=DocChatInterfaceDemo,ChatInterfaceDemo,RagChatInterfaceDemo,TextCompletionDemo
|
85 |
+
|
86 |
+
|
87 |
+
|
88 |
+
# ! server info
|
89 |
+
|
90 |
+
DELETE_FOLDER = os.environ.get("DELETE_FOLDER", "")
|
91 |
+
PORT = int(os.environ.get("PORT", "7860"))
|
92 |
+
PROXY = os.environ.get("PROXY", "").strip()
|
93 |
+
|
94 |
+
# ! backend info
|
95 |
+
|
96 |
+
BACKEND = os.environ.get("BACKEND", "debug")
|
97 |
+
|
98 |
+
# ! model information
|
99 |
+
# for RAG
|
100 |
+
RAG_EMBED_MODEL_NAME = os.environ.get("RAG_EMBED_MODEL_NAME", "sentence-transformers/all-MiniLM-L6-v2")
|
101 |
+
CHUNK_SIZE = int(os.environ.get("CHUNK_SIZE", "1024"))
|
102 |
+
CHUNK_OVERLAP = int(os.environ.get("CHUNK_SIZE", "50"))
|
103 |
+
|
104 |
+
|
105 |
+
SYSTEM_PROMPT = os.environ.get("SYSTEM_PROMPT", """You are a helpful, respectful, honest and safe AI assistant.""")
|
106 |
+
|
107 |
+
MAX_TOKENS = int(os.environ.get("MAX_TOKENS", "2048"))
|
108 |
+
TEMPERATURE = float(os.environ.get("TEMPERATURE", "0.1"))
|
109 |
+
# ! these values currently not used
|
110 |
+
FREQUENCE_PENALTY = float(os.environ.get("FREQUENCE_PENALTY", "0.0"))
|
111 |
+
PRESENCE_PENALTY = float(os.environ.get("PRESENCE_PENALTY", "0.0"))
|
112 |
+
|
113 |
+
|
114 |
+
# Transformers or vllm
|
115 |
+
MODEL_PATH = os.environ.get("MODEL_PATH", "mistralai/Mistral-7B-Instruct-v0.2")
|
116 |
+
MODEL_NAME = os.environ.get("MODEL_NAME", "Cool-Chatbot")
|
117 |
+
DTYPE = os.environ.get("DTYPE", "bfloat16")
|
118 |
+
DEVICE = os.environ.get("DEVICE", "cuda")
|
119 |
+
|
120 |
+
# VLLM
|
121 |
+
GPU_MEMORY_UTILIZATION = float(os.environ.get("GPU_MEMORY_UTILIZATION", "0.9"))
|
122 |
+
TENSOR_PARALLEL = int(os.environ.get("TENSOR_PARALLEL", "1"))
|
123 |
+
QUANTIZATION = str(os.environ.get("QUANTIZATION", ""))
|
124 |
+
STREAM_YIELD_MULTIPLE = int(os.environ.get("STREAM_YIELD_MULTIPLE", "1"))
|
125 |
+
# how many iterations to perform safety check on response
|
126 |
+
STREAM_CHECK_MULTIPLE = int(os.environ.get("STREAM_CHECK_MULTIPLE", "0"))
|
127 |
+
|
128 |
+
# llama.cpp
|
129 |
+
DEFAULT_CHAT_TEMPLATE = os.environ.get("DEFAULT_CHAT_TEMPLATE", "chatml")
|
130 |
+
N_CTX = int(os.environ.get("N_CTX", "4096"))
|
131 |
+
N_GPU_LAYERS = int(os.environ.get("N_GPU_LAYERS", "-1"))
|
132 |
+
|
133 |
+
# llava.llama.cpp
|
134 |
+
|
135 |
+
|
136 |
+
# Multimodal
|
137 |
+
IMAGE_TOKEN = os.environ.get("IMAGE_TOKEN", "[IMAGE]<|image|>[/IMAGE]")
|
138 |
+
IMAGE_TOKEN_INTERACTIVE = bool(int(os.environ.get("IMAGE_TOKEN_INTERACTIVE", "0")))
|
139 |
+
IMAGE_TOKEN_LENGTH = int(os.environ.get("IMAGE_TOKEN_LENGTH", "576"))
|
140 |
+
MAX_PACHES = int(os.environ.get("MAX_PACHES", "1"))
|
multipurpose_chatbot/demos/.DS_Store
ADDED
Binary file (6.15 kB). View file
|
|
multipurpose_chatbot/demos/__init__.py
ADDED
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
|
2 |
+
from .base_demo import *
|
3 |
+
|
4 |
+
from .chat_interface import ChatInterfaceDemo
|
5 |
+
from .rag_chat_interface import RagChatInterfaceDemo
|
6 |
+
from .multimodal_chat_interface import *
|
7 |
+
from .text_completion import *
|
8 |
+
from .batch_inference import *
|
9 |
+
from .multimodal_preference_interface import *
|
multipurpose_chatbot/demos/base_demo.py
ADDED
@@ -0,0 +1,105 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
from gradio.themes import ThemeClass as Theme
|
3 |
+
import numpy as np
|
4 |
+
import argparse
|
5 |
+
import gradio as gr
|
6 |
+
from typing import Any, Iterator
|
7 |
+
from typing import Iterator, List, Optional, Tuple
|
8 |
+
import filelock
|
9 |
+
import glob
|
10 |
+
import json
|
11 |
+
import time
|
12 |
+
from gradio.routes import Request
|
13 |
+
from gradio.utils import SyncToAsyncIterator, async_iteration
|
14 |
+
from gradio.helpers import special_args
|
15 |
+
import anyio
|
16 |
+
from typing import AsyncGenerator, Callable, Literal, Union, cast
|
17 |
+
|
18 |
+
from gradio_client.documentation import document, set_documentation_group
|
19 |
+
from gradio.components import Button, Component
|
20 |
+
from gradio.events import Dependency, EventListenerMethod
|
21 |
+
from typing import List, Optional, Union, Dict, Tuple
|
22 |
+
from tqdm.auto import tqdm
|
23 |
+
from huggingface_hub import snapshot_download
|
24 |
+
|
25 |
+
|
26 |
+
def create_class_func_registry():
|
27 |
+
registry = {}
|
28 |
+
def register_registry(cls, exist_ok=False):
|
29 |
+
assert exist_ok or cls.__name__ not in registry, f'{cls} already in registry: {registry}'
|
30 |
+
registry[cls.__name__] = cls
|
31 |
+
return cls
|
32 |
+
|
33 |
+
def get_registry(name):
|
34 |
+
assert name in registry, f'{name} not in registry: {registry}'
|
35 |
+
return registry[name]
|
36 |
+
|
37 |
+
return registry, register_registry, get_registry
|
38 |
+
|
39 |
+
DEMOS, register_demo, get_demo_class = create_class_func_registry()
|
40 |
+
|
41 |
+
|
42 |
+
class BaseDemo(object):
|
43 |
+
"""
|
44 |
+
All demo should be created from BaseDemo and registered with @register_demo
|
45 |
+
"""
|
46 |
+
def __init__(self) -> None:
|
47 |
+
pass
|
48 |
+
|
49 |
+
@property
|
50 |
+
def tab_name(self):
|
51 |
+
return "Demo"
|
52 |
+
|
53 |
+
def create_demo(
|
54 |
+
self,
|
55 |
+
title: Optional[str] = None,
|
56 |
+
description: Optional[str] = None,
|
57 |
+
**kwargs,
|
58 |
+
) -> gr.Blocks:
|
59 |
+
pass
|
60 |
+
|
61 |
+
|
62 |
+
@document()
|
63 |
+
class CustomTabbedInterface(gr.Blocks):
|
64 |
+
def __init__(
|
65 |
+
self,
|
66 |
+
interface_list: list[gr.Interface],
|
67 |
+
tab_names: Optional[list[str]] = None,
|
68 |
+
title: Optional[str] = None,
|
69 |
+
description: Optional[str] = None,
|
70 |
+
theme: Optional[gr.Theme] = None,
|
71 |
+
analytics_enabled: Optional[bool] = None,
|
72 |
+
css: Optional[str] = None,
|
73 |
+
):
|
74 |
+
"""
|
75 |
+
Parameters:
|
76 |
+
interface_list: a list of interfaces to be rendered in tabs.
|
77 |
+
tab_names: a list of tab names. If None, the tab names will be "Tab 1", "Tab 2", etc.
|
78 |
+
title: a title for the interface; if provided, appears above the input and output components in large font. Also used as the tab title when opened in a browser window.
|
79 |
+
analytics_enabled: whether to allow basic telemetry. If None, will use GRADIO_ANALYTICS_ENABLED environment variable or default to True.
|
80 |
+
css: custom css or path to custom css file to apply to entire Blocks
|
81 |
+
Returns:
|
82 |
+
a Gradio Tabbed Interface for the given interfaces
|
83 |
+
"""
|
84 |
+
super().__init__(
|
85 |
+
title=title or "Gradio",
|
86 |
+
theme=theme,
|
87 |
+
analytics_enabled=analytics_enabled,
|
88 |
+
mode="tabbed_interface",
|
89 |
+
css=css,
|
90 |
+
)
|
91 |
+
self.description = description
|
92 |
+
if tab_names is None:
|
93 |
+
tab_names = [f"Tab {i}" for i in range(len(interface_list))]
|
94 |
+
with self:
|
95 |
+
if title:
|
96 |
+
gr.Markdown(
|
97 |
+
f"<h1 style='text-align: center; margin-bottom: 1rem'>{title}</h1>"
|
98 |
+
)
|
99 |
+
if description:
|
100 |
+
gr.Markdown(description)
|
101 |
+
with gr.Tabs():
|
102 |
+
for interface, tab_name in zip(interface_list, tab_names):
|
103 |
+
with gr.Tab(label=tab_name):
|
104 |
+
interface.render()
|
105 |
+
|
multipurpose_chatbot/demos/batch_inference.py
ADDED
File without changes
|
multipurpose_chatbot/demos/chat_interface.py
ADDED
@@ -0,0 +1,692 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
from gradio.themes import ThemeClass as Theme
|
3 |
+
import numpy as np
|
4 |
+
import argparse
|
5 |
+
import gradio as gr
|
6 |
+
from typing import Any, Iterator
|
7 |
+
from typing import Iterator, List, Optional, Tuple
|
8 |
+
import filelock
|
9 |
+
import glob
|
10 |
+
import json
|
11 |
+
import time
|
12 |
+
from gradio.routes import Request
|
13 |
+
from gradio.utils import SyncToAsyncIterator, async_iteration
|
14 |
+
from gradio.helpers import special_args
|
15 |
+
import anyio
|
16 |
+
from typing import AsyncGenerator, Callable, Literal, Union, cast, Generator
|
17 |
+
|
18 |
+
from gradio_client.documentation import document, set_documentation_group
|
19 |
+
from gradio.components import Button, Component
|
20 |
+
from gradio.events import Dependency, EventListenerMethod
|
21 |
+
from typing import List, Optional, Union, Dict, Tuple
|
22 |
+
from tqdm.auto import tqdm
|
23 |
+
from huggingface_hub import snapshot_download
|
24 |
+
|
25 |
+
|
26 |
+
import inspect
|
27 |
+
from typing import AsyncGenerator, Callable, Literal, Union, cast
|
28 |
+
|
29 |
+
import anyio
|
30 |
+
from gradio_client import utils as client_utils
|
31 |
+
from gradio_client.documentation import document
|
32 |
+
|
33 |
+
from gradio.blocks import Blocks
|
34 |
+
from gradio.components import (
|
35 |
+
Button,
|
36 |
+
Chatbot,
|
37 |
+
Component,
|
38 |
+
Markdown,
|
39 |
+
State,
|
40 |
+
Textbox,
|
41 |
+
get_component_instance,
|
42 |
+
)
|
43 |
+
from gradio.events import Dependency, on
|
44 |
+
from gradio.helpers import create_examples as Examples # noqa: N812
|
45 |
+
from gradio.helpers import special_args
|
46 |
+
from gradio.layouts import Accordion, Group, Row
|
47 |
+
from gradio.routes import Request
|
48 |
+
from gradio.themes import ThemeClass as Theme
|
49 |
+
from gradio.utils import SyncToAsyncIterator, async_iteration
|
50 |
+
|
51 |
+
|
52 |
+
from .base_demo import register_demo, get_demo_class, BaseDemo
|
53 |
+
from ..configs import (
|
54 |
+
SYSTEM_PROMPT,
|
55 |
+
MODEL_NAME,
|
56 |
+
MAX_TOKENS,
|
57 |
+
TEMPERATURE,
|
58 |
+
)
|
59 |
+
|
60 |
+
from ..globals import MODEL_ENGINE
|
61 |
+
|
62 |
+
CHAT_EXAMPLES = [
|
63 |
+
["Explain general relativity."],
|
64 |
+
]
|
65 |
+
DATETIME_FORMAT = "Current date time: {cur_datetime}."
|
66 |
+
|
67 |
+
|
68 |
+
def gradio_history_to_openai_conversations(message=None, history=None, system_prompt=None):
|
69 |
+
conversations = []
|
70 |
+
system_prompt = system_prompt or SYSTEM_PROMPT
|
71 |
+
if history is not None and len(history) > 0:
|
72 |
+
for i, (prompt, res) in enumerate(history):
|
73 |
+
if prompt is not None:
|
74 |
+
conversations.append({"role": "user", "content": prompt.strip()})
|
75 |
+
if res is not None:
|
76 |
+
conversations.append({"role": "assistant", "content": res.strip()})
|
77 |
+
if message is not None:
|
78 |
+
if len(message.strip()) == 0:
|
79 |
+
raise gr.Error("The message cannot be empty!")
|
80 |
+
conversations.append({"role": "user", "content": message.strip()})
|
81 |
+
if conversations[0]['role'] != 'system':
|
82 |
+
conversations = [{"role": "system", "content": system_prompt}] + conversations
|
83 |
+
return conversations
|
84 |
+
|
85 |
+
|
86 |
+
def gradio_history_to_conversation_prompt(message=None, history=None, system_prompt=None):
|
87 |
+
global MODEL_ENGINE
|
88 |
+
full_prompt = MODEL_ENGINE.apply_chat_template(
|
89 |
+
gradio_history_to_openai_conversations(
|
90 |
+
message, history=history, system_prompt=system_prompt),
|
91 |
+
add_generation_prompt=True
|
92 |
+
)
|
93 |
+
return full_prompt
|
94 |
+
|
95 |
+
|
96 |
+
|
97 |
+
def get_datetime_string():
|
98 |
+
from datetime import datetime
|
99 |
+
now = datetime.now()
|
100 |
+
# dd/mm/YY H:M:S
|
101 |
+
dt_string = now.strftime("%B %d, %Y, %H:%M:%S")
|
102 |
+
return dt_string
|
103 |
+
|
104 |
+
|
105 |
+
def format_conversation(history, system_prompt=None):
|
106 |
+
_str = '\n'.join([
|
107 |
+
(
|
108 |
+
f'<<<User>>> {h[0]}\n'
|
109 |
+
f'<<<Asst>>> {h[1]}'
|
110 |
+
)
|
111 |
+
for h in history
|
112 |
+
])
|
113 |
+
if system_prompt is not None:
|
114 |
+
_str = f"<<<System>>> {system_prompt}\n" + _str
|
115 |
+
return _str
|
116 |
+
|
117 |
+
|
118 |
+
def chat_response_stream_multiturn_engine(
|
119 |
+
message: str,
|
120 |
+
history: List[Tuple[str, str]],
|
121 |
+
temperature: float,
|
122 |
+
max_tokens: int,
|
123 |
+
system_prompt: Optional[str] = SYSTEM_PROMPT,
|
124 |
+
):
|
125 |
+
global MODEL_ENGINE
|
126 |
+
temperature = float(temperature)
|
127 |
+
# ! remove frequency_penalty
|
128 |
+
# frequency_penalty = float(frequency_penalty)
|
129 |
+
max_tokens = int(max_tokens)
|
130 |
+
message = message.strip()
|
131 |
+
if len(message) == 0:
|
132 |
+
raise gr.Error("The message cannot be empty!")
|
133 |
+
# ! skip safety
|
134 |
+
if DATETIME_FORMAT in system_prompt:
|
135 |
+
# ! This sometime works sometimes dont
|
136 |
+
system_prompt = system_prompt.format(cur_datetime=get_datetime_string())
|
137 |
+
full_prompt = gradio_history_to_conversation_prompt(message.strip(), history=history, system_prompt=system_prompt)
|
138 |
+
# ! length checked
|
139 |
+
num_tokens = len(MODEL_ENGINE.tokenizer.encode(full_prompt))
|
140 |
+
if num_tokens >= MODEL_ENGINE.max_position_embeddings - 128:
|
141 |
+
raise gr.Error(f"Conversation or prompt is too long ({num_tokens} toks), please clear the chatbox or try shorter input.")
|
142 |
+
print(full_prompt)
|
143 |
+
outputs = None
|
144 |
+
response = None
|
145 |
+
num_tokens = -1
|
146 |
+
for j, outputs in enumerate(MODEL_ENGINE.generate_yield_string(
|
147 |
+
prompt=full_prompt,
|
148 |
+
temperature=temperature,
|
149 |
+
max_tokens=max_tokens,
|
150 |
+
)):
|
151 |
+
if isinstance(outputs, tuple):
|
152 |
+
response, num_tokens = outputs
|
153 |
+
else:
|
154 |
+
response, num_tokens = outputs, -1
|
155 |
+
yield response, num_tokens
|
156 |
+
|
157 |
+
if response is not None:
|
158 |
+
yield response, num_tokens
|
159 |
+
|
160 |
+
|
161 |
+
class CustomizedChatInterface(gr.ChatInterface):
|
162 |
+
"""
|
163 |
+
Fixing some issue with chatinterace
|
164 |
+
"""
|
165 |
+
|
166 |
+
def __init__(
|
167 |
+
self,
|
168 |
+
fn: Callable,
|
169 |
+
*,
|
170 |
+
chatbot: Chatbot | None = None,
|
171 |
+
textbox: Textbox | None = None,
|
172 |
+
additional_inputs: str | Component | list[str | Component] | None = None,
|
173 |
+
additional_inputs_accordion_name: str | None = None,
|
174 |
+
additional_inputs_accordion: str | Accordion | None = None,
|
175 |
+
examples: list[str] | None = None,
|
176 |
+
cache_examples: bool | None = None,
|
177 |
+
title: str | None = None,
|
178 |
+
description: str | None = None,
|
179 |
+
theme: Theme | str | None = None,
|
180 |
+
css: str | None = None,
|
181 |
+
js: str | None = None,
|
182 |
+
head: str | None = None,
|
183 |
+
analytics_enabled: bool | None = None,
|
184 |
+
submit_btn: str | None | Button = "Submit",
|
185 |
+
stop_btn: str | None | Button = "Stop",
|
186 |
+
retry_btn: str | None | Button = "🔄 Retry",
|
187 |
+
undo_btn: str | None | Button = "↩️ Undo",
|
188 |
+
clear_btn: str | None | Button = "🗑️ Clear",
|
189 |
+
autofocus: bool = True,
|
190 |
+
concurrency_limit: int | None | Literal["default"] = "default",
|
191 |
+
fill_height: bool = True,
|
192 |
+
):
|
193 |
+
"""
|
194 |
+
Parameters:
|
195 |
+
fn: The function to wrap the chat interface around. Should accept two parameters: a string input message and list of two-element lists of the form [[user_message, bot_message], ...] representing the chat history, and return a string response. See the Chatbot documentation for more information on the chat history format.
|
196 |
+
chatbot: An instance of the gr.Chatbot component to use for the chat interface, if you would like to customize the chatbot properties. If not provided, a default gr.Chatbot component will be created.
|
197 |
+
textbox: An instance of the gr.Textbox component to use for the chat interface, if you would like to customize the textbox properties. If not provided, a default gr.Textbox component will be created.
|
198 |
+
additional_inputs: An instance or list of instances of gradio components (or their string shortcuts) to use as additional inputs to the chatbot. If components are not already rendered in a surrounding Blocks, then the components will be displayed under the chatbot, in an accordion.
|
199 |
+
additional_inputs_accordion_name: Deprecated. Will be removed in a future version of Gradio. Use the `additional_inputs_accordion` parameter instead.
|
200 |
+
additional_inputs_accordion: If a string is provided, this is the label of the `gr.Accordion` to use to contain additional inputs. A `gr.Accordion` object can be provided as well to configure other properties of the container holding the additional inputs. Defaults to a `gr.Accordion(label="Additional Inputs", open=False)`. This parameter is only used if `additional_inputs` is provided.
|
201 |
+
examples: Sample inputs for the function; if provided, appear below the chatbot and can be clicked to populate the chatbot input.
|
202 |
+
cache_examples: If True, caches examples in the server for fast runtime in examples. The default option in HuggingFace Spaces is True. The default option elsewhere is False.
|
203 |
+
title: a title for the interface; if provided, appears above chatbot in large font. Also used as the tab title when opened in a browser window.
|
204 |
+
description: a description for the interface; if provided, appears above the chatbot and beneath the title in regular font. Accepts Markdown and HTML content.
|
205 |
+
theme: Theme to use, loaded from gradio.themes.
|
206 |
+
css: Custom css as a string or path to a css file. This css will be included in the demo webpage.
|
207 |
+
js: Custom js or path to js file to run when demo is first loaded. This javascript will be included in the demo webpage.
|
208 |
+
head: Custom html to insert into the head of the demo webpage. This can be used to add custom meta tags, scripts, stylesheets, etc. to the page.
|
209 |
+
analytics_enabled: Whether to allow basic telemetry. If None, will use GRADIO_ANALYTICS_ENABLED environment variable if defined, or default to True.
|
210 |
+
submit_btn: Text to display on the submit button. If None, no button will be displayed. If a Button object, that button will be used.
|
211 |
+
stop_btn: Text to display on the stop button, which replaces the submit_btn when the submit_btn or retry_btn is clicked and response is streaming. Clicking on the stop_btn will halt the chatbot response. If set to None, stop button functionality does not appear in the chatbot. If a Button object, that button will be used as the stop button.
|
212 |
+
retry_btn: Text to display on the retry button. If None, no button will be displayed. If a Button object, that button will be used.
|
213 |
+
undo_btn: Text to display on the delete last button. If None, no button will be displayed. If a Button object, that button will be used.
|
214 |
+
clear_btn: Text to display on the clear button. If None, no button will be displayed. If a Button object, that button will be used.
|
215 |
+
autofocus: If True, autofocuses to the textbox when the page loads.
|
216 |
+
concurrency_limit: If set, this is the maximum number of chatbot submissions that can be running simultaneously. Can be set to None to mean no limit (any number of chatbot submissions can be running simultaneously). Set to "default" to use the default concurrency limit (defined by the `default_concurrency_limit` parameter in `.queue()`, which is 1 by default).
|
217 |
+
fill_height: If True, the chat interface will expand to the height of window.
|
218 |
+
"""
|
219 |
+
try:
|
220 |
+
super(gr.ChatInterface, self).__init__(
|
221 |
+
analytics_enabled=analytics_enabled,
|
222 |
+
mode="chat_interface",
|
223 |
+
css=css,
|
224 |
+
title=title or "Gradio",
|
225 |
+
theme=theme,
|
226 |
+
js=js,
|
227 |
+
head=head,
|
228 |
+
fill_height=fill_height,
|
229 |
+
)
|
230 |
+
except Exception as e:
|
231 |
+
# Handling some old gradio version with out fill_height
|
232 |
+
super(gr.ChatInterface, self).__init__(
|
233 |
+
analytics_enabled=analytics_enabled,
|
234 |
+
mode="chat_interface",
|
235 |
+
css=css,
|
236 |
+
title=title or "Gradio",
|
237 |
+
theme=theme,
|
238 |
+
js=js,
|
239 |
+
head=head,
|
240 |
+
# fill_height=fill_height,
|
241 |
+
)
|
242 |
+
self.concurrency_limit = concurrency_limit
|
243 |
+
self.fn = fn
|
244 |
+
self.is_async = inspect.iscoroutinefunction(
|
245 |
+
self.fn
|
246 |
+
) or inspect.isasyncgenfunction(self.fn)
|
247 |
+
self.is_generator = inspect.isgeneratorfunction(
|
248 |
+
self.fn
|
249 |
+
) or inspect.isasyncgenfunction(self.fn)
|
250 |
+
self.examples = examples
|
251 |
+
if self.space_id and cache_examples is None:
|
252 |
+
self.cache_examples = True
|
253 |
+
else:
|
254 |
+
self.cache_examples = cache_examples or False
|
255 |
+
self.buttons: list[Button | None] = []
|
256 |
+
|
257 |
+
if additional_inputs:
|
258 |
+
if not isinstance(additional_inputs, list):
|
259 |
+
additional_inputs = [additional_inputs]
|
260 |
+
self.additional_inputs = [
|
261 |
+
get_component_instance(i)
|
262 |
+
for i in additional_inputs # type: ignore
|
263 |
+
]
|
264 |
+
else:
|
265 |
+
self.additional_inputs = []
|
266 |
+
if additional_inputs_accordion_name is not None:
|
267 |
+
print(
|
268 |
+
"The `additional_inputs_accordion_name` parameter is deprecated and will be removed in a future version of Gradio. Use the `additional_inputs_accordion` parameter instead."
|
269 |
+
)
|
270 |
+
self.additional_inputs_accordion_params = {
|
271 |
+
"label": additional_inputs_accordion_name
|
272 |
+
}
|
273 |
+
if additional_inputs_accordion is None:
|
274 |
+
self.additional_inputs_accordion_params = {
|
275 |
+
"label": "Additional Inputs",
|
276 |
+
"open": False,
|
277 |
+
}
|
278 |
+
elif isinstance(additional_inputs_accordion, str):
|
279 |
+
self.additional_inputs_accordion_params = {
|
280 |
+
"label": additional_inputs_accordion
|
281 |
+
}
|
282 |
+
elif isinstance(additional_inputs_accordion, Accordion):
|
283 |
+
self.additional_inputs_accordion_params = (
|
284 |
+
additional_inputs_accordion.recover_kwargs(
|
285 |
+
additional_inputs_accordion.get_config()
|
286 |
+
)
|
287 |
+
)
|
288 |
+
else:
|
289 |
+
raise ValueError(
|
290 |
+
f"The `additional_inputs_accordion` parameter must be a string or gr.Accordion, not {type(additional_inputs_accordion)}"
|
291 |
+
)
|
292 |
+
|
293 |
+
with self:
|
294 |
+
if title:
|
295 |
+
Markdown(
|
296 |
+
f"<h1 style='text-align: center; margin-bottom: 1rem'>{self.title}</h1>"
|
297 |
+
)
|
298 |
+
if description:
|
299 |
+
Markdown(description)
|
300 |
+
|
301 |
+
if chatbot:
|
302 |
+
self.chatbot = chatbot.render()
|
303 |
+
else:
|
304 |
+
self.chatbot = Chatbot(
|
305 |
+
label="Chatbot", scale=1, height=200 if fill_height else None
|
306 |
+
)
|
307 |
+
|
308 |
+
with Row():
|
309 |
+
for btn in [retry_btn, undo_btn, clear_btn]:
|
310 |
+
if btn is not None:
|
311 |
+
if isinstance(btn, Button):
|
312 |
+
btn.render()
|
313 |
+
elif isinstance(btn, str):
|
314 |
+
btn = Button(btn, variant="secondary", size="sm")
|
315 |
+
else:
|
316 |
+
raise ValueError(
|
317 |
+
f"All the _btn parameters must be a gr.Button, string, or None, not {type(btn)}"
|
318 |
+
)
|
319 |
+
self.buttons.append(btn) # type: ignore
|
320 |
+
|
321 |
+
with Group():
|
322 |
+
with Row():
|
323 |
+
if textbox:
|
324 |
+
textbox.container = False
|
325 |
+
textbox.show_label = False
|
326 |
+
textbox_ = textbox.render()
|
327 |
+
assert isinstance(textbox_, Textbox)
|
328 |
+
self.textbox = textbox_
|
329 |
+
else:
|
330 |
+
self.textbox = Textbox(
|
331 |
+
container=False,
|
332 |
+
show_label=False,
|
333 |
+
label="Message",
|
334 |
+
placeholder="Type a message...",
|
335 |
+
scale=7,
|
336 |
+
autofocus=autofocus,
|
337 |
+
)
|
338 |
+
if submit_btn is not None:
|
339 |
+
if isinstance(submit_btn, Button):
|
340 |
+
submit_btn.render()
|
341 |
+
elif isinstance(submit_btn, str):
|
342 |
+
submit_btn = Button(
|
343 |
+
submit_btn,
|
344 |
+
variant="primary",
|
345 |
+
scale=2,
|
346 |
+
min_width=150,
|
347 |
+
)
|
348 |
+
else:
|
349 |
+
raise ValueError(
|
350 |
+
f"The submit_btn parameter must be a gr.Button, string, or None, not {type(submit_btn)}"
|
351 |
+
)
|
352 |
+
if stop_btn is not None:
|
353 |
+
if isinstance(stop_btn, Button):
|
354 |
+
stop_btn.visible = False
|
355 |
+
stop_btn.render()
|
356 |
+
elif isinstance(stop_btn, str):
|
357 |
+
stop_btn = Button(
|
358 |
+
stop_btn,
|
359 |
+
variant="stop",
|
360 |
+
visible=False,
|
361 |
+
scale=2,
|
362 |
+
min_width=150,
|
363 |
+
)
|
364 |
+
else:
|
365 |
+
raise ValueError(
|
366 |
+
f"The stop_btn parameter must be a gr.Button, string, or None, not {type(stop_btn)}"
|
367 |
+
)
|
368 |
+
self.num_tokens = Textbox(
|
369 |
+
container=False,
|
370 |
+
show_label=False,
|
371 |
+
label="num_tokens",
|
372 |
+
placeholder="0 tokens",
|
373 |
+
scale=1,
|
374 |
+
interactive=False,
|
375 |
+
# autofocus=autofocus,
|
376 |
+
min_width=10
|
377 |
+
)
|
378 |
+
self.buttons.extend([submit_btn, stop_btn]) # type: ignore
|
379 |
+
|
380 |
+
self.fake_api_btn = Button("Fake API", visible=False)
|
381 |
+
self.fake_response_textbox = Textbox(label="Response", visible=False)
|
382 |
+
(
|
383 |
+
self.retry_btn,
|
384 |
+
self.undo_btn,
|
385 |
+
self.clear_btn,
|
386 |
+
self.submit_btn,
|
387 |
+
self.stop_btn,
|
388 |
+
) = self.buttons
|
389 |
+
|
390 |
+
if examples:
|
391 |
+
if self.is_generator:
|
392 |
+
examples_fn = self._examples_stream_fn
|
393 |
+
else:
|
394 |
+
examples_fn = self._examples_fn
|
395 |
+
|
396 |
+
self.examples_handler = Examples(
|
397 |
+
examples=examples,
|
398 |
+
inputs=[self.textbox] + self.additional_inputs,
|
399 |
+
outputs=self.chatbot,
|
400 |
+
fn=examples_fn,
|
401 |
+
)
|
402 |
+
|
403 |
+
any_unrendered_inputs = any(
|
404 |
+
not inp.is_rendered for inp in self.additional_inputs
|
405 |
+
)
|
406 |
+
if self.additional_inputs and any_unrendered_inputs:
|
407 |
+
with Accordion(**self.additional_inputs_accordion_params): # type: ignore
|
408 |
+
for input_component in self.additional_inputs:
|
409 |
+
if not input_component.is_rendered:
|
410 |
+
input_component.render()
|
411 |
+
|
412 |
+
# The example caching must happen after the input components have rendered
|
413 |
+
if cache_examples:
|
414 |
+
client_utils.synchronize_async(self.examples_handler.cache)
|
415 |
+
|
416 |
+
self.saved_input = State()
|
417 |
+
self.chatbot_state = (
|
418 |
+
State(self.chatbot.value) if self.chatbot.value else State([])
|
419 |
+
)
|
420 |
+
|
421 |
+
self._setup_events()
|
422 |
+
self._setup_api()
|
423 |
+
|
424 |
+
# replace events so that submit button is disabled during generation, if stop_btn not found
|
425 |
+
# this prevent weird behavior
|
426 |
+
def _setup_stop_events(
|
427 |
+
self, event_triggers: list[EventListenerMethod], event_to_cancel: Dependency
|
428 |
+
) -> None:
|
429 |
+
from gradio.components import State
|
430 |
+
event_triggers = event_triggers if isinstance(event_triggers, (list, tuple)) else [event_triggers]
|
431 |
+
if self.stop_btn and self.is_generator:
|
432 |
+
if self.submit_btn:
|
433 |
+
for event_trigger in event_triggers:
|
434 |
+
event_trigger(
|
435 |
+
lambda: (
|
436 |
+
Button(visible=False),
|
437 |
+
Button(visible=True),
|
438 |
+
),
|
439 |
+
None,
|
440 |
+
[self.submit_btn, self.stop_btn],
|
441 |
+
api_name=False,
|
442 |
+
queue=False,
|
443 |
+
)
|
444 |
+
event_to_cancel.then(
|
445 |
+
lambda: (Button(visible=True), Button(visible=False)),
|
446 |
+
None,
|
447 |
+
[self.submit_btn, self.stop_btn],
|
448 |
+
api_name=False,
|
449 |
+
queue=False,
|
450 |
+
)
|
451 |
+
else:
|
452 |
+
for event_trigger in event_triggers:
|
453 |
+
event_trigger(
|
454 |
+
lambda: Button(visible=True),
|
455 |
+
None,
|
456 |
+
[self.stop_btn],
|
457 |
+
api_name=False,
|
458 |
+
queue=False,
|
459 |
+
)
|
460 |
+
event_to_cancel.then(
|
461 |
+
lambda: Button(visible=False),
|
462 |
+
None,
|
463 |
+
[self.stop_btn],
|
464 |
+
api_name=False,
|
465 |
+
queue=False,
|
466 |
+
)
|
467 |
+
self.stop_btn.click(
|
468 |
+
None,
|
469 |
+
None,
|
470 |
+
None,
|
471 |
+
cancels=event_to_cancel,
|
472 |
+
api_name=False,
|
473 |
+
)
|
474 |
+
else:
|
475 |
+
if self.submit_btn:
|
476 |
+
for event_trigger in event_triggers:
|
477 |
+
event_trigger(
|
478 |
+
lambda: Button(interactive=False),
|
479 |
+
None,
|
480 |
+
[self.submit_btn],
|
481 |
+
api_name=False,
|
482 |
+
queue=False,
|
483 |
+
)
|
484 |
+
event_to_cancel.then(
|
485 |
+
lambda: Button(interactive=True),
|
486 |
+
None,
|
487 |
+
[self.submit_btn],
|
488 |
+
api_name=False,
|
489 |
+
queue=False,
|
490 |
+
)
|
491 |
+
# upon clear, cancel the submit event as well
|
492 |
+
if self.clear_btn:
|
493 |
+
self.clear_btn.click(
|
494 |
+
lambda: ([], [], None, Button(interactive=True)),
|
495 |
+
None,
|
496 |
+
[self.chatbot, self.chatbot_state, self.saved_input, self.submit_btn],
|
497 |
+
queue=False,
|
498 |
+
api_name=False,
|
499 |
+
cancels=event_to_cancel,
|
500 |
+
)
|
501 |
+
|
502 |
+
def _setup_events(self) -> None:
|
503 |
+
from gradio.components import State
|
504 |
+
has_on = False
|
505 |
+
try:
|
506 |
+
from gradio.events import Dependency, EventListenerMethod, on
|
507 |
+
has_on = True
|
508 |
+
except ImportError as ie:
|
509 |
+
has_on = False
|
510 |
+
submit_fn = self._stream_fn if self.is_generator else self._submit_fn
|
511 |
+
if not self.is_generator:
|
512 |
+
raise NotImplementedError(f'should use generator')
|
513 |
+
|
514 |
+
if has_on:
|
515 |
+
# new version
|
516 |
+
submit_triggers = (
|
517 |
+
[self.textbox.submit, self.submit_btn.click]
|
518 |
+
if self.submit_btn
|
519 |
+
else [self.textbox.submit]
|
520 |
+
)
|
521 |
+
submit_event = (
|
522 |
+
on(
|
523 |
+
submit_triggers,
|
524 |
+
self._clear_and_save_textbox,
|
525 |
+
[self.textbox],
|
526 |
+
[self.textbox, self.saved_input],
|
527 |
+
api_name=False,
|
528 |
+
queue=False,
|
529 |
+
)
|
530 |
+
.then(
|
531 |
+
self._display_input,
|
532 |
+
[self.saved_input, self.chatbot_state],
|
533 |
+
[self.chatbot, self.chatbot_state],
|
534 |
+
api_name=False,
|
535 |
+
queue=False,
|
536 |
+
)
|
537 |
+
.then(
|
538 |
+
submit_fn,
|
539 |
+
[self.saved_input, self.chatbot_state] + self.additional_inputs,
|
540 |
+
[self.chatbot, self.chatbot_state, self.num_tokens],
|
541 |
+
api_name=False,
|
542 |
+
)
|
543 |
+
)
|
544 |
+
self._setup_stop_events(submit_triggers, submit_event)
|
545 |
+
else:
|
546 |
+
raise ValueError(f'Better install new gradio version than 3.44.0')
|
547 |
+
|
548 |
+
if self.retry_btn:
|
549 |
+
retry_event = (
|
550 |
+
self.retry_btn.click(
|
551 |
+
self._delete_prev_fn,
|
552 |
+
[self.chatbot_state],
|
553 |
+
[self.chatbot, self.saved_input, self.chatbot_state],
|
554 |
+
api_name=False,
|
555 |
+
queue=False,
|
556 |
+
)
|
557 |
+
.then(
|
558 |
+
self._display_input,
|
559 |
+
[self.saved_input, self.chatbot_state],
|
560 |
+
[self.chatbot, self.chatbot_state],
|
561 |
+
api_name=False,
|
562 |
+
queue=False,
|
563 |
+
)
|
564 |
+
.then(
|
565 |
+
submit_fn,
|
566 |
+
[self.saved_input, self.chatbot_state] + self.additional_inputs,
|
567 |
+
[self.chatbot, self.chatbot_state, self.num_tokens],
|
568 |
+
api_name=False,
|
569 |
+
)
|
570 |
+
)
|
571 |
+
self._setup_stop_events([self.retry_btn.click], retry_event)
|
572 |
+
|
573 |
+
if self.undo_btn:
|
574 |
+
self.undo_btn.click(
|
575 |
+
self._delete_prev_fn,
|
576 |
+
[self.chatbot_state],
|
577 |
+
[self.chatbot, self.saved_input, self.chatbot_state],
|
578 |
+
api_name=False,
|
579 |
+
queue=False,
|
580 |
+
).then(
|
581 |
+
lambda x: x,
|
582 |
+
[self.saved_input],
|
583 |
+
[self.textbox],
|
584 |
+
api_name=False,
|
585 |
+
queue=False,
|
586 |
+
)
|
587 |
+
# Reconfigure clear_btn to stop and clear text box
|
588 |
+
|
589 |
+
def _clear_and_save_textbox(self, message: str) -> tuple[str, str]:
|
590 |
+
return "", message
|
591 |
+
|
592 |
+
def _display_input(
|
593 |
+
self, message: str, history: List[List[Union[str, None]]]
|
594 |
+
) -> Tuple[List[List[Union[str, None]]], List[List[list[Union[str, None]]]]]:
|
595 |
+
if message is not None and message.strip() != "":
|
596 |
+
history.append([message, None])
|
597 |
+
return history, history
|
598 |
+
|
599 |
+
async def _stream_fn(
|
600 |
+
self,
|
601 |
+
message: str,
|
602 |
+
history_with_input,
|
603 |
+
request: Request,
|
604 |
+
*args,
|
605 |
+
) -> AsyncGenerator:
|
606 |
+
history = history_with_input[:-1]
|
607 |
+
inputs, _, _ = special_args(
|
608 |
+
self.fn, inputs=[message, history, *args], request=request
|
609 |
+
)
|
610 |
+
|
611 |
+
if self.is_async:
|
612 |
+
generator = self.fn(*inputs)
|
613 |
+
else:
|
614 |
+
generator = await anyio.to_thread.run_sync(
|
615 |
+
self.fn, *inputs, limiter=self.limiter
|
616 |
+
)
|
617 |
+
generator = SyncToAsyncIterator(generator, self.limiter)
|
618 |
+
|
619 |
+
# ! In case of error, yield the previous history & undo any generation before raising error
|
620 |
+
try:
|
621 |
+
first_response_pack = await async_iteration(generator)
|
622 |
+
if isinstance(first_response_pack, (tuple, list)):
|
623 |
+
first_response, num_tokens = first_response_pack
|
624 |
+
else:
|
625 |
+
first_response, num_tokens = first_response_pack, -1
|
626 |
+
update = history + [[message, first_response]]
|
627 |
+
yield update, update, f"{num_tokens} toks"
|
628 |
+
except StopIteration:
|
629 |
+
update = history + [[message, None]]
|
630 |
+
yield update, update, "NaN toks"
|
631 |
+
except Exception as e:
|
632 |
+
yield history, history, "NaN toks"
|
633 |
+
raise e
|
634 |
+
|
635 |
+
try:
|
636 |
+
async for response_pack in generator:
|
637 |
+
if isinstance(response_pack, (tuple, list)):
|
638 |
+
response, num_tokens = response_pack
|
639 |
+
else:
|
640 |
+
response, num_tokens = response_pack, "NaN toks"
|
641 |
+
update = history + [[message, response]]
|
642 |
+
yield update, update, f"{num_tokens} toks"
|
643 |
+
except Exception as e:
|
644 |
+
yield history, history, "NaN toks"
|
645 |
+
raise e
|
646 |
+
|
647 |
+
@register_demo
|
648 |
+
class ChatInterfaceDemo(BaseDemo):
|
649 |
+
@property
|
650 |
+
def tab_name(self):
|
651 |
+
return "Chat"
|
652 |
+
|
653 |
+
def create_demo(
|
654 |
+
self,
|
655 |
+
title: str | None = None,
|
656 |
+
description: str | None = None,
|
657 |
+
**kwargs
|
658 |
+
) -> gr.Blocks:
|
659 |
+
system_prompt = kwargs.get("system_prompt", SYSTEM_PROMPT)
|
660 |
+
max_tokens = kwargs.get("max_tokens", MAX_TOKENS)
|
661 |
+
temperature = kwargs.get("temperature", TEMPERATURE)
|
662 |
+
model_name = kwargs.get("model_name", MODEL_NAME)
|
663 |
+
# frequence_penalty = FREQUENCE_PENALTY
|
664 |
+
# presence_penalty = PRESENCE_PENALTY
|
665 |
+
|
666 |
+
demo_chat = CustomizedChatInterface(
|
667 |
+
chat_response_stream_multiturn_engine,
|
668 |
+
chatbot=gr.Chatbot(
|
669 |
+
label=model_name,
|
670 |
+
bubble_full_width=False,
|
671 |
+
latex_delimiters=[
|
672 |
+
{ "left": "$", "right": "$", "display": False},
|
673 |
+
{ "left": "$$", "right": "$$", "display": True},
|
674 |
+
],
|
675 |
+
show_copy_button=True,
|
676 |
+
),
|
677 |
+
textbox=gr.Textbox(placeholder='Type message', lines=1, max_lines=128, min_width=200, scale=8),
|
678 |
+
submit_btn=gr.Button(value='Submit', variant="primary", scale=0),
|
679 |
+
title=title,
|
680 |
+
description=description,
|
681 |
+
additional_inputs=[
|
682 |
+
gr.Number(value=temperature, label='Temperature (higher -> more random)'),
|
683 |
+
gr.Number(value=max_tokens, label='Max generated tokens (increase if want more generation)'),
|
684 |
+
# gr.Number(value=frequence_penalty, label='Frequency penalty (> 0 encourage new tokens over repeated tokens)'),
|
685 |
+
# gr.Number(value=presence_penalty, label='Presence penalty (> 0 encourage new tokens, < 0 encourage existing tokens)'),
|
686 |
+
gr.Textbox(value=system_prompt, label='System prompt', lines=4)
|
687 |
+
],
|
688 |
+
examples=CHAT_EXAMPLES,
|
689 |
+
cache_examples=False
|
690 |
+
)
|
691 |
+
return demo_chat
|
692 |
+
|
multipurpose_chatbot/demos/multimodal_chat_interface.py
ADDED
@@ -0,0 +1,1295 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
from gradio.themes import ThemeClass as Theme
|
3 |
+
import numpy as np
|
4 |
+
import argparse
|
5 |
+
import gradio as gr
|
6 |
+
from typing import Any, Iterator
|
7 |
+
from typing import Iterator, List, Optional, Tuple
|
8 |
+
import filelock
|
9 |
+
import glob
|
10 |
+
import json
|
11 |
+
import time
|
12 |
+
from gradio.routes import Request
|
13 |
+
from gradio.utils import SyncToAsyncIterator, async_iteration
|
14 |
+
from gradio.helpers import special_args
|
15 |
+
import anyio
|
16 |
+
from typing import AsyncGenerator, Callable, Literal, Union, cast, Generator
|
17 |
+
|
18 |
+
from gradio_client.documentation import document, set_documentation_group
|
19 |
+
from gradio.components import Button, Component
|
20 |
+
from gradio.events import Dependency, EventListenerMethod
|
21 |
+
from typing import List, Optional, Union, Dict, Tuple
|
22 |
+
from tqdm.auto import tqdm
|
23 |
+
from huggingface_hub import snapshot_download
|
24 |
+
from gradio.components.base import Component
|
25 |
+
|
26 |
+
from .base_demo import register_demo, get_demo_class, BaseDemo
|
27 |
+
|
28 |
+
|
29 |
+
from .chat_interface import (
|
30 |
+
SYSTEM_PROMPT,
|
31 |
+
MODEL_NAME,
|
32 |
+
MAX_TOKENS,
|
33 |
+
TEMPERATURE,
|
34 |
+
CHAT_EXAMPLES,
|
35 |
+
gradio_history_to_openai_conversations,
|
36 |
+
gradio_history_to_conversation_prompt,
|
37 |
+
DATETIME_FORMAT,
|
38 |
+
get_datetime_string,
|
39 |
+
chat_response_stream_multiturn_engine,
|
40 |
+
ChatInterfaceDemo,
|
41 |
+
CustomizedChatInterface,
|
42 |
+
)
|
43 |
+
|
44 |
+
from gradio.events import Events
|
45 |
+
|
46 |
+
import inspect
|
47 |
+
from typing import AsyncGenerator, Callable, Literal, Union, cast
|
48 |
+
|
49 |
+
import anyio
|
50 |
+
from gradio_client import utils as client_utils
|
51 |
+
from gradio_client.documentation import document
|
52 |
+
|
53 |
+
from gradio.blocks import Blocks
|
54 |
+
from gradio.components import (
|
55 |
+
Button,
|
56 |
+
Chatbot,
|
57 |
+
Component,
|
58 |
+
Markdown,
|
59 |
+
State,
|
60 |
+
Textbox,
|
61 |
+
get_component_instance,
|
62 |
+
)
|
63 |
+
from gradio.events import Dependency, on
|
64 |
+
from gradio.helpers import create_examples as Examples # noqa: N812
|
65 |
+
from gradio.helpers import special_args
|
66 |
+
from gradio.layouts import Accordion, Group, Row
|
67 |
+
from gradio.routes import Request
|
68 |
+
from gradio.themes import ThemeClass as Theme
|
69 |
+
from gradio.utils import SyncToAsyncIterator, async_iteration
|
70 |
+
|
71 |
+
from ..globals import MODEL_ENGINE
|
72 |
+
|
73 |
+
from ..configs import (
|
74 |
+
USE_PANEL,
|
75 |
+
IMAGE_TOKEN,
|
76 |
+
IMAGE_TOKEN_INTERACTIVE,
|
77 |
+
CHATBOT_HEIGHT,
|
78 |
+
)
|
79 |
+
|
80 |
+
|
81 |
+
|
82 |
+
CSS = """
|
83 |
+
.message-fit {
|
84 |
+
min-width: 20em;
|
85 |
+
width: fit-content !important;
|
86 |
+
}
|
87 |
+
|
88 |
+
.message.svelte-1lcyrx4.svelte-1lcyrx4.svelte-1lcyrx4 {
|
89 |
+
padding-top: 1em;
|
90 |
+
padding-bottom: 1em;
|
91 |
+
}
|
92 |
+
"""
|
93 |
+
|
94 |
+
|
95 |
+
DOC_TEMPLATE = """###
|
96 |
+
{content}
|
97 |
+
###
|
98 |
+
|
99 |
+
"""
|
100 |
+
|
101 |
+
DOC_INSTRUCTION = """Answer the following query exclusively based on the information provided in the document above. \
|
102 |
+
If the information is not found, please say so instead of making up facts! Remember to answer the question in the same language as the user query!
|
103 |
+
"""
|
104 |
+
|
105 |
+
|
106 |
+
def undo_history(history):
|
107 |
+
if len(history) == 0:
|
108 |
+
return history
|
109 |
+
if history[-1][-1] is not None:
|
110 |
+
if history[-1][0] is not None:
|
111 |
+
history[-1][-1] = None
|
112 |
+
else:
|
113 |
+
history = history[:-1]
|
114 |
+
else:
|
115 |
+
history = history[:-1]
|
116 |
+
return history
|
117 |
+
|
118 |
+
|
119 |
+
def undo_history_until_last_assistant_turn(history):
|
120 |
+
history = undo_history(history)
|
121 |
+
while len(history) > 0 and history[-1][-1] is None:
|
122 |
+
history = undo_history(history)
|
123 |
+
return history, history
|
124 |
+
|
125 |
+
|
126 |
+
class MultiModalChatInterface(CustomizedChatInterface):
|
127 |
+
def __init__(
|
128 |
+
self,
|
129 |
+
fn: Callable,
|
130 |
+
*,
|
131 |
+
chatbot: Chatbot | None = None,
|
132 |
+
textbox: Textbox | None = None,
|
133 |
+
additional_inputs: str | Component | list[str | Component] | None = None,
|
134 |
+
additional_inputs_accordion_name: str | None = None,
|
135 |
+
additional_inputs_accordion: str | Accordion | None = None,
|
136 |
+
add_multimodal_fn: Callable | None = None,
|
137 |
+
render_additional_inputs_fn: Callable | None = None,
|
138 |
+
examples: list[str] | None = None,
|
139 |
+
cache_examples: bool | None = None,
|
140 |
+
title: str | None = None,
|
141 |
+
description: str | None = None,
|
142 |
+
theme: Theme | str | None = None,
|
143 |
+
css: str | None = None,
|
144 |
+
js: str | None = None,
|
145 |
+
head: str | None = None,
|
146 |
+
analytics_enabled: bool | None = None,
|
147 |
+
submit_btn: str | None | Button = "Submit",
|
148 |
+
stop_btn: str | None | Button = "Stop",
|
149 |
+
retry_btn: str | None | Button = "🔄 Retry",
|
150 |
+
undo_btn: str | None | Button = "↩️ Undo",
|
151 |
+
clear_btn: str | None | Button = "🗑️ Clear",
|
152 |
+
autofocus: bool = True,
|
153 |
+
concurrency_limit: int | None | Literal["default"] = "default",
|
154 |
+
fill_height: bool = True,
|
155 |
+
):
|
156 |
+
"""
|
157 |
+
Parameters:
|
158 |
+
fn: The function to wrap the chat interface around. Should accept two parameters: a string input message and list of two-element lists of the form [[user_message, bot_message], ...] representing the chat history, and return a string response. See the Chatbot documentation for more information on the chat history format.
|
159 |
+
chatbot: An instance of the gr.Chatbot component to use for the chat interface, if you would like to customize the chatbot properties. If not provided, a default gr.Chatbot component will be created.
|
160 |
+
textbox: An instance of the gr.Textbox component to use for the chat interface, if you would like to customize the textbox properties. If not provided, a default gr.Textbox component will be created.
|
161 |
+
additional_inputs: An instance or list of instances of gradio components (or their string shortcuts) to use as additional inputs to the chatbot. If components are not already rendered in a surrounding Blocks, then the components will be displayed under the chatbot, in an accordion.
|
162 |
+
additional_inputs_accordion_name: Deprecated. Will be removed in a future version of Gradio. Use the `additional_inputs_accordion` parameter instead.
|
163 |
+
additional_inputs_accordion: If a string is provided, this is the label of the `gr.Accordion` to use to contain additional inputs. A `gr.Accordion` object can be provided as well to configure other properties of the container holding the additional inputs. Defaults to a `gr.Accordion(label="Additional Inputs", open=False)`. This parameter is only used if `additional_inputs` is provided.
|
164 |
+
examples: Sample inputs for the function; if provided, appear below the chatbot and can be clicked to populate the chatbot input.
|
165 |
+
cache_examples: If True, caches examples in the server for fast runtime in examples. The default option in HuggingFace Spaces is True. The default option elsewhere is False.
|
166 |
+
title: a title for the interface; if provided, appears above chatbot in large font. Also used as the tab title when opened in a browser window.
|
167 |
+
description: a description for the interface; if provided, appears above the chatbot and beneath the title in regular font. Accepts Markdown and HTML content.
|
168 |
+
theme: Theme to use, loaded from gradio.themes.
|
169 |
+
css: Custom css as a string or path to a css file. This css will be included in the demo webpage.
|
170 |
+
js: Custom js or path to js file to run when demo is first loaded. This javascript will be included in the demo webpage.
|
171 |
+
head: Custom html to insert into the head of the demo webpage. This can be used to add custom meta tags, scripts, stylesheets, etc. to the page.
|
172 |
+
analytics_enabled: Whether to allow basic telemetry. If None, will use GRADIO_ANALYTICS_ENABLED environment variable if defined, or default to True.
|
173 |
+
submit_btn: Text to display on the submit button. If None, no button will be displayed. If a Button object, that button will be used.
|
174 |
+
stop_btn: Text to display on the stop button, which replaces the submit_btn when the submit_btn or retry_btn is clicked and response is streaming. Clicking on the stop_btn will halt the chatbot response. If set to None, stop button functionality does not appear in the chatbot. If a Button object, that button will be used as the stop button.
|
175 |
+
retry_btn: Text to display on the retry button. If None, no button will be displayed. If a Button object, that button will be used.
|
176 |
+
undo_btn: Text to display on the delete last button. If None, no button will be displayed. If a Button object, that button will be used.
|
177 |
+
clear_btn: Text to display on the clear button. If None, no button will be displayed. If a Button object, that button will be used.
|
178 |
+
autofocus: If True, autofocuses to the textbox when the page loads.
|
179 |
+
concurrency_limit: If set, this is the maximum number of chatbot submissions that can be running simultaneously. Can be set to None to mean no limit (any number of chatbot submissions can be running simultaneously). Set to "default" to use the default concurrency limit (defined by the `default_concurrency_limit` parameter in `.queue()`, which is 1 by default).
|
180 |
+
fill_height: If True, the chat interface will expand to the height of window.
|
181 |
+
"""
|
182 |
+
try:
|
183 |
+
super(gr.ChatInterface, self).__init__(
|
184 |
+
analytics_enabled=analytics_enabled,
|
185 |
+
mode="chat_interface",
|
186 |
+
css=css,
|
187 |
+
title=title or "Gradio",
|
188 |
+
theme=theme,
|
189 |
+
js=js,
|
190 |
+
head=head,
|
191 |
+
fill_height=fill_height,
|
192 |
+
)
|
193 |
+
except Exception as e:
|
194 |
+
# Handle old gradio versions without fill_height
|
195 |
+
super(gr.ChatInterface, self).__init__(
|
196 |
+
analytics_enabled=analytics_enabled,
|
197 |
+
mode="chat_interface",
|
198 |
+
css=css,
|
199 |
+
title=title or "Gradio",
|
200 |
+
theme=theme,
|
201 |
+
js=js,
|
202 |
+
head=head,
|
203 |
+
# fill_height=fill_height,
|
204 |
+
)
|
205 |
+
|
206 |
+
self.concurrency_limit = concurrency_limit
|
207 |
+
self.fn = fn
|
208 |
+
self.add_multimodal_fn = add_multimodal_fn
|
209 |
+
self.render_additional_inputs_fn = render_additional_inputs_fn
|
210 |
+
self.multimodal_inputs = []
|
211 |
+
self.is_async = inspect.iscoroutinefunction(
|
212 |
+
self.fn
|
213 |
+
) or inspect.isasyncgenfunction(self.fn)
|
214 |
+
self.is_generator = inspect.isgeneratorfunction(
|
215 |
+
self.fn
|
216 |
+
) or inspect.isasyncgenfunction(self.fn)
|
217 |
+
self.examples = examples
|
218 |
+
if self.space_id and cache_examples is None:
|
219 |
+
self.cache_examples = True
|
220 |
+
else:
|
221 |
+
self.cache_examples = cache_examples or False
|
222 |
+
self.buttons: list[Button | None] = []
|
223 |
+
|
224 |
+
if additional_inputs:
|
225 |
+
if not isinstance(additional_inputs, list):
|
226 |
+
additional_inputs = [additional_inputs]
|
227 |
+
self.additional_inputs = [
|
228 |
+
get_component_instance(i)
|
229 |
+
for i in additional_inputs # type: ignore
|
230 |
+
]
|
231 |
+
else:
|
232 |
+
self.additional_inputs = []
|
233 |
+
if additional_inputs_accordion_name is not None:
|
234 |
+
print(
|
235 |
+
"The `additional_inputs_accordion_name` parameter is deprecated and will be removed in a future version of Gradio. Use the `additional_inputs_accordion` parameter instead."
|
236 |
+
)
|
237 |
+
self.additional_inputs_accordion_params = {
|
238 |
+
"label": additional_inputs_accordion_name
|
239 |
+
}
|
240 |
+
if additional_inputs_accordion is None:
|
241 |
+
self.additional_inputs_accordion_params = {
|
242 |
+
"label": "Additional Inputs",
|
243 |
+
"open": False,
|
244 |
+
}
|
245 |
+
elif isinstance(additional_inputs_accordion, str):
|
246 |
+
self.additional_inputs_accordion_params = {
|
247 |
+
"label": additional_inputs_accordion
|
248 |
+
}
|
249 |
+
elif isinstance(additional_inputs_accordion, Accordion):
|
250 |
+
self.additional_inputs_accordion_params = (
|
251 |
+
additional_inputs_accordion.recover_kwargs(
|
252 |
+
additional_inputs_accordion.get_config()
|
253 |
+
)
|
254 |
+
)
|
255 |
+
else:
|
256 |
+
raise ValueError(
|
257 |
+
f"The `additional_inputs_accordion` parameter must be a string or gr.Accordion, not {type(additional_inputs_accordion)}"
|
258 |
+
)
|
259 |
+
|
260 |
+
with self:
|
261 |
+
if title:
|
262 |
+
Markdown(
|
263 |
+
f"<h1 style='text-align: center; margin-bottom: 1rem'>{self.title}</h1>"
|
264 |
+
)
|
265 |
+
if description:
|
266 |
+
Markdown(description)
|
267 |
+
|
268 |
+
if chatbot:
|
269 |
+
self.chatbot = chatbot.render()
|
270 |
+
else:
|
271 |
+
self.chatbot = Chatbot(
|
272 |
+
label="Chatbot", scale=1, height=200 if fill_height else None
|
273 |
+
)
|
274 |
+
|
275 |
+
with Row():
|
276 |
+
for btn in [retry_btn, undo_btn, clear_btn]:
|
277 |
+
if btn is not None:
|
278 |
+
if isinstance(btn, Button):
|
279 |
+
btn.render()
|
280 |
+
elif isinstance(btn, str):
|
281 |
+
btn = Button(btn, variant="secondary", size="sm")
|
282 |
+
else:
|
283 |
+
raise ValueError(
|
284 |
+
f"All the _btn parameters must be a gr.Button, string, or None, not {type(btn)}"
|
285 |
+
)
|
286 |
+
self.buttons.append(btn) # type: ignore
|
287 |
+
|
288 |
+
with Group():
|
289 |
+
with Row():
|
290 |
+
if textbox:
|
291 |
+
textbox.container = False
|
292 |
+
textbox.show_label = False
|
293 |
+
textbox_ = textbox.render()
|
294 |
+
assert isinstance(textbox_, Textbox)
|
295 |
+
self.textbox = textbox_
|
296 |
+
else:
|
297 |
+
self.textbox = Textbox(
|
298 |
+
container=False,
|
299 |
+
show_label=False,
|
300 |
+
label="Message",
|
301 |
+
placeholder="Type a message...",
|
302 |
+
scale=7,
|
303 |
+
autofocus=autofocus,
|
304 |
+
)
|
305 |
+
if submit_btn is not None:
|
306 |
+
if isinstance(submit_btn, Button):
|
307 |
+
submit_btn.render()
|
308 |
+
elif isinstance(submit_btn, str):
|
309 |
+
submit_btn = Button(
|
310 |
+
submit_btn,
|
311 |
+
variant="primary",
|
312 |
+
scale=2,
|
313 |
+
min_width=150,
|
314 |
+
)
|
315 |
+
else:
|
316 |
+
raise ValueError(
|
317 |
+
f"The submit_btn parameter must be a gr.Button, string, or None, not {type(submit_btn)}"
|
318 |
+
)
|
319 |
+
if stop_btn is not None:
|
320 |
+
if isinstance(stop_btn, Button):
|
321 |
+
stop_btn.visible = False
|
322 |
+
stop_btn.render()
|
323 |
+
elif isinstance(stop_btn, str):
|
324 |
+
stop_btn = Button(
|
325 |
+
stop_btn,
|
326 |
+
variant="stop",
|
327 |
+
visible=False,
|
328 |
+
scale=2,
|
329 |
+
min_width=150,
|
330 |
+
)
|
331 |
+
else:
|
332 |
+
raise ValueError(
|
333 |
+
f"The stop_btn parameter must be a gr.Button, string, or None, not {type(stop_btn)}"
|
334 |
+
)
|
335 |
+
self.num_tokens = Textbox(
|
336 |
+
container=False,
|
337 |
+
show_label=False,
|
338 |
+
label="num_tokens",
|
339 |
+
placeholder="0 tokens",
|
340 |
+
scale=1,
|
341 |
+
interactive=False,
|
342 |
+
# autofocus=autofocus,
|
343 |
+
min_width=10
|
344 |
+
)
|
345 |
+
self.buttons.extend([submit_btn, stop_btn]) # type: ignore
|
346 |
+
|
347 |
+
self.fake_api_btn = Button("Fake API", visible=False)
|
348 |
+
self.fake_response_textbox = Textbox(label="Response", visible=False)
|
349 |
+
(
|
350 |
+
self.retry_btn,
|
351 |
+
self.undo_btn,
|
352 |
+
self.clear_btn,
|
353 |
+
self.submit_btn,
|
354 |
+
self.stop_btn,
|
355 |
+
) = self.buttons
|
356 |
+
|
357 |
+
|
358 |
+
any_unrendered_inputs = any(
|
359 |
+
not inp.is_rendered for inp in self.additional_inputs
|
360 |
+
)
|
361 |
+
if self.add_multimodal_fn is not None:
|
362 |
+
with Row():
|
363 |
+
self.multimodal_inputs = self.add_multimodal_fn()
|
364 |
+
if self.additional_inputs and any_unrendered_inputs:
|
365 |
+
with Accordion(**self.additional_inputs_accordion_params): # type: ignore
|
366 |
+
if self.render_additional_inputs_fn is not None:
|
367 |
+
self.render_additional_inputs_fn()
|
368 |
+
else:
|
369 |
+
for input_component in self.additional_inputs:
|
370 |
+
if not input_component.is_rendered:
|
371 |
+
input_component.render()
|
372 |
+
else:
|
373 |
+
if self.additional_inputs and any_unrendered_inputs:
|
374 |
+
with Accordion(**self.additional_inputs_accordion_params): # type: ignore
|
375 |
+
if self.render_additional_inputs_fn is not None:
|
376 |
+
self.render_additional_inputs_fn()
|
377 |
+
else:
|
378 |
+
for input_component in self.additional_inputs:
|
379 |
+
if not input_component.is_rendered:
|
380 |
+
input_component.render()
|
381 |
+
|
382 |
+
if examples:
|
383 |
+
if self.is_generator:
|
384 |
+
examples_fn = self._examples_stream_fn
|
385 |
+
else:
|
386 |
+
# examples_fn = self._examples_fn
|
387 |
+
raise NotImplementedError(f'Not streaming not impl')
|
388 |
+
|
389 |
+
self.examples_handler = Examples(
|
390 |
+
examples=examples,
|
391 |
+
inputs=[self.textbox] + self.multimodal_inputs + self.additional_inputs,
|
392 |
+
outputs=self.chatbot,
|
393 |
+
fn=examples_fn,
|
394 |
+
)
|
395 |
+
|
396 |
+
# The example caching must happen after the input components have rendered
|
397 |
+
if cache_examples:
|
398 |
+
client_utils.synchronize_async(self.examples_handler.cache)
|
399 |
+
|
400 |
+
self.saved_input = State()
|
401 |
+
self.chatbot_state = (
|
402 |
+
State(self.chatbot.value) if self.chatbot.value else State([])
|
403 |
+
)
|
404 |
+
|
405 |
+
self._setup_events()
|
406 |
+
self._setup_api()
|
407 |
+
|
408 |
+
def _clear_and_save_textbox(self, message: str, *multimodal_inputs) -> tuple[str, str]:
|
409 |
+
saved_input = [message] + list(multimodal_inputs)
|
410 |
+
outputs = [''] + [None] * len(multimodal_inputs)
|
411 |
+
return outputs + [saved_input]
|
412 |
+
|
413 |
+
def _add_inputs_to_history(self, history: List[List[Union[str, None]]], *args):
|
414 |
+
message = args[0]
|
415 |
+
multimodal_inputs = args[1:1 + len(self.multimodal_inputs)] if len(args) > 1 else None
|
416 |
+
if multimodal_inputs is not None:
|
417 |
+
is_file_exists = [(x is not None and os.path.exists(x)) for x in multimodal_inputs]
|
418 |
+
if any(is_file_exists):
|
419 |
+
file_exists = [f for f, ise in zip(multimodal_inputs, is_file_exists) if ise]
|
420 |
+
if len(file_exists) > 1:
|
421 |
+
raise gr.Error(f"Cannot have more than 1 multimodal input at a time.")
|
422 |
+
fname = file_exists[0]
|
423 |
+
history.append([(fname,), None])
|
424 |
+
if message is not None and message.strip() != "":
|
425 |
+
history.append([message, None])
|
426 |
+
return history
|
427 |
+
|
428 |
+
|
429 |
+
def _display_input(
|
430 |
+
self, saved_input: List[str], history: List[List[Union[str, None]]]
|
431 |
+
) -> Tuple[List[List[Union[str, None]]], List[List[list[Union[str, None]]]]]:
|
432 |
+
# message = saved_input[0]
|
433 |
+
# multimodal_inputs = saved_input[1:] if len(saved_input) > 1 else None
|
434 |
+
# # ! If things wrong, return original history and give warning
|
435 |
+
# if multimodal_inputs is not None:
|
436 |
+
# is_file_exists = [(x is not None and os.path.exists(x)) for x in multimodal_inputs]
|
437 |
+
# if any(is_file_exists):
|
438 |
+
# file_exists = [f for f, ise in zip(multimodal_inputs, is_file_exists) if ise]
|
439 |
+
# if len(file_exists) > 1:
|
440 |
+
# raise gr.Error(f"Cannot have more than 1 multimodal input at a time.")
|
441 |
+
# fname = file_exists[0]
|
442 |
+
# history.append([(fname,), None])
|
443 |
+
# if message is not None and message.strip() != "":
|
444 |
+
# history.append([message, None])
|
445 |
+
history = self._add_inputs_to_history(history, *saved_input)
|
446 |
+
return history, history
|
447 |
+
|
448 |
+
def _delete_prev_fn(
|
449 |
+
self, history: list[list[str | None]]
|
450 |
+
) -> tuple[list[list[str | None]], str, list[list[str | None]]]:
|
451 |
+
try:
|
452 |
+
message, _ = history.pop()
|
453 |
+
except IndexError:
|
454 |
+
message = ""
|
455 |
+
saved_input = [message or ""] + [None] * len(self.multimodal_inputs)
|
456 |
+
return history, saved_input, history
|
457 |
+
|
458 |
+
def _setup_events(self) -> None:
|
459 |
+
from gradio.components import State
|
460 |
+
has_on = False
|
461 |
+
try:
|
462 |
+
from gradio.events import Dependency, EventListenerMethod, on
|
463 |
+
has_on = True
|
464 |
+
except ImportError as ie:
|
465 |
+
has_on = False
|
466 |
+
submit_fn = self._stream_fn if self.is_generator else self._submit_fn
|
467 |
+
if not self.is_generator:
|
468 |
+
raise NotImplementedError(f'should use generator')
|
469 |
+
|
470 |
+
if has_on:
|
471 |
+
# new version
|
472 |
+
submit_triggers = (
|
473 |
+
[self.textbox.submit, self.submit_btn.click]
|
474 |
+
if self.submit_btn
|
475 |
+
else [self.textbox.submit]
|
476 |
+
)
|
477 |
+
submit_event = (
|
478 |
+
on(
|
479 |
+
submit_triggers,
|
480 |
+
self._clear_and_save_textbox,
|
481 |
+
[self.textbox] + self.multimodal_inputs,
|
482 |
+
[self.textbox] + self.multimodal_inputs + [self.saved_input],
|
483 |
+
api_name=False,
|
484 |
+
queue=False,
|
485 |
+
)
|
486 |
+
.then(
|
487 |
+
self._display_input,
|
488 |
+
[self.saved_input, self.chatbot_state],
|
489 |
+
[self.chatbot, self.chatbot_state],
|
490 |
+
api_name=False,
|
491 |
+
queue=False,
|
492 |
+
)
|
493 |
+
.success(
|
494 |
+
submit_fn,
|
495 |
+
[self.chatbot_state] + self.additional_inputs,
|
496 |
+
[self.chatbot, self.chatbot_state, self.num_tokens],
|
497 |
+
api_name=False,
|
498 |
+
)
|
499 |
+
)
|
500 |
+
self._setup_stop_events(submit_triggers, submit_event)
|
501 |
+
else:
|
502 |
+
raise ValueError(f'Better install new gradio version than 3.44.0')
|
503 |
+
|
504 |
+
if self.retry_btn:
|
505 |
+
retry_event = (
|
506 |
+
self.retry_btn.click(
|
507 |
+
self._delete_prev_fn,
|
508 |
+
[self.chatbot_state],
|
509 |
+
[self.chatbot, self.saved_input, self.chatbot_state],
|
510 |
+
api_name=False,
|
511 |
+
queue=False,
|
512 |
+
)
|
513 |
+
.then(
|
514 |
+
self._display_input,
|
515 |
+
[self.saved_input, self.chatbot_state],
|
516 |
+
[self.chatbot, self.chatbot_state],
|
517 |
+
api_name=False,
|
518 |
+
queue=False,
|
519 |
+
)
|
520 |
+
.success(
|
521 |
+
submit_fn,
|
522 |
+
[self.chatbot_state] + self.additional_inputs,
|
523 |
+
[self.chatbot, self.chatbot_state, self.num_tokens],
|
524 |
+
api_name=False,
|
525 |
+
)
|
526 |
+
)
|
527 |
+
self._setup_stop_events([self.retry_btn.click], retry_event)
|
528 |
+
|
529 |
+
if self.undo_btn:
|
530 |
+
self.undo_btn.click(
|
531 |
+
# self._delete_prev_fn,
|
532 |
+
# [self.chatbot_state],
|
533 |
+
# [self.chatbot, self.saved_input, self.chatbot_state],
|
534 |
+
undo_history_until_last_assistant_turn,
|
535 |
+
[self.chatbot_state],
|
536 |
+
[self.chatbot, self.chatbot_state],
|
537 |
+
api_name=False,
|
538 |
+
queue=False,
|
539 |
+
)
|
540 |
+
# .then(
|
541 |
+
# lambda x: x,
|
542 |
+
# [self.saved_input],
|
543 |
+
# [self.textbox],
|
544 |
+
# api_name=False,
|
545 |
+
# queue=False,
|
546 |
+
# )
|
547 |
+
|
548 |
+
async def _stream_fn(
|
549 |
+
self,
|
550 |
+
# message: str,
|
551 |
+
history_with_input,
|
552 |
+
request: Request,
|
553 |
+
*args,
|
554 |
+
) -> AsyncGenerator:
|
555 |
+
history = history_with_input[:-1]
|
556 |
+
message = history_with_input[-1][0]
|
557 |
+
inputs, _, _ = special_args(
|
558 |
+
self.fn, inputs=[history_with_input, *args], request=request
|
559 |
+
)
|
560 |
+
|
561 |
+
if self.is_async:
|
562 |
+
generator = self.fn(*inputs)
|
563 |
+
else:
|
564 |
+
generator = await anyio.to_thread.run_sync(
|
565 |
+
self.fn, *inputs, limiter=self.limiter
|
566 |
+
)
|
567 |
+
generator = SyncToAsyncIterator(generator, self.limiter)
|
568 |
+
|
569 |
+
# ! In case of error, yield the previous history & undo any generation before raising error
|
570 |
+
try:
|
571 |
+
first_response_pack = await async_iteration(generator)
|
572 |
+
if isinstance(first_response_pack, (tuple, list)):
|
573 |
+
first_response, num_tokens = first_response_pack
|
574 |
+
else:
|
575 |
+
first_response, num_tokens = first_response_pack, -1
|
576 |
+
update = history + [[message, first_response]]
|
577 |
+
yield update, update, f"{num_tokens} toks"
|
578 |
+
except StopIteration:
|
579 |
+
update = history + [[message, None]]
|
580 |
+
yield update, update, "NaN toks"
|
581 |
+
except Exception as e:
|
582 |
+
yield history, history, "NaN toks"
|
583 |
+
raise e
|
584 |
+
|
585 |
+
try:
|
586 |
+
async for response_pack in generator:
|
587 |
+
if isinstance(response_pack, (tuple, list)):
|
588 |
+
response, num_tokens = response_pack
|
589 |
+
else:
|
590 |
+
response, num_tokens = response_pack, "NaN toks"
|
591 |
+
update = history + [[message, response]]
|
592 |
+
yield update, update, f"{num_tokens} toks"
|
593 |
+
except Exception as e:
|
594 |
+
yield history, history, "NaN toks"
|
595 |
+
raise e
|
596 |
+
|
597 |
+
async def _examples_stream_fn(
|
598 |
+
self,
|
599 |
+
# message: str,
|
600 |
+
*args,
|
601 |
+
) -> AsyncGenerator:
|
602 |
+
history = []
|
603 |
+
input_len = 1 + len(self.multimodal_inputs)
|
604 |
+
saved_input = args[:input_len]
|
605 |
+
message = saved_input[0]
|
606 |
+
additional_inputs = [] if len(args) <= input_len else args[input_len:]
|
607 |
+
history = self._add_inputs_to_history(history, *saved_input)
|
608 |
+
inputs, _, _ = special_args(self.fn, inputs=[history, *additional_inputs], request=None)
|
609 |
+
|
610 |
+
if self.is_async:
|
611 |
+
generator = self.fn(*inputs)
|
612 |
+
else:
|
613 |
+
generator = await anyio.to_thread.run_sync(
|
614 |
+
self.fn, *inputs, limiter=self.limiter
|
615 |
+
)
|
616 |
+
generator = SyncToAsyncIterator(generator, self.limiter)
|
617 |
+
# async for response in generator:
|
618 |
+
# yield [[message, response]]
|
619 |
+
|
620 |
+
try:
|
621 |
+
async for response_pack in generator:
|
622 |
+
if isinstance(response_pack, (tuple, list)):
|
623 |
+
response, num_tokens = response_pack
|
624 |
+
else:
|
625 |
+
response, num_tokens = response_pack, "NaN toks"
|
626 |
+
update = history + [[message, response]]
|
627 |
+
yield update, update, f"{num_tokens} toks"
|
628 |
+
except Exception as e:
|
629 |
+
yield history, history, "NaN toks"
|
630 |
+
raise e
|
631 |
+
|
632 |
+
async def _examples_fn(self, message: str, *args) -> list[list[str | None]]:
|
633 |
+
raise NotImplementedError
|
634 |
+
inputs, _, _ = special_args(self.fn, inputs=[message, [], *args], request=None)
|
635 |
+
|
636 |
+
if self.is_async:
|
637 |
+
response = await self.fn(*inputs)
|
638 |
+
else:
|
639 |
+
response = await anyio.to_thread.run_sync(
|
640 |
+
self.fn, *inputs, limiter=self.limiter
|
641 |
+
)
|
642 |
+
return [[message, response]]
|
643 |
+
|
644 |
+
|
645 |
+
|
646 |
+
def gradio_history_to_openai_conversations(message=None, history=None, system_prompt=None):
|
647 |
+
conversations = []
|
648 |
+
system_prompt = system_prompt or SYSTEM_PROMPT
|
649 |
+
if history is not None and len(history) > 0:
|
650 |
+
for i, (prompt, res) in enumerate(history):
|
651 |
+
if prompt is not None:
|
652 |
+
conversations.append({"role": "user", "content": prompt.strip()})
|
653 |
+
if res is not None:
|
654 |
+
conversations.append({"role": "assistant", "content": res.strip()})
|
655 |
+
if message is not None:
|
656 |
+
if len(message.strip()) == 0:
|
657 |
+
raise gr.Error("The message cannot be empty!")
|
658 |
+
conversations.append({"role": "user", "content": message.strip()})
|
659 |
+
if conversations[0]['role'] != 'system':
|
660 |
+
conversations = [{"role": "system", "content": system_prompt}] + conversations
|
661 |
+
return conversations
|
662 |
+
|
663 |
+
|
664 |
+
def gradio_history_to_conversation_prompt(message=None, history=None, system_prompt=None):
|
665 |
+
global MODEL_ENGINE
|
666 |
+
full_prompt = MODEL_ENGINE.apply_chat_template(
|
667 |
+
gradio_history_to_openai_conversations(
|
668 |
+
message, history=history, system_prompt=system_prompt),
|
669 |
+
add_generation_prompt=True
|
670 |
+
)
|
671 |
+
return full_prompt
|
672 |
+
|
673 |
+
|
674 |
+
def gradio_history_to_vision_conversations_paths(
|
675 |
+
history, system_prompt=None, image_token=None
|
676 |
+
):
|
677 |
+
image_token = image_token or IMAGE_TOKEN
|
678 |
+
conversations = []
|
679 |
+
image_paths = []
|
680 |
+
for i, his in enumerate(history):
|
681 |
+
prompt, response = his
|
682 |
+
last_turn = conversations[-1] if len(conversations) > 0 else None
|
683 |
+
if prompt is not None:
|
684 |
+
if isinstance(prompt, tuple):
|
685 |
+
image_path = prompt[0]
|
686 |
+
if last_turn is not None and last_turn['role'] == 'user':
|
687 |
+
last_turn['content'] += f" {image_token}"
|
688 |
+
else:
|
689 |
+
# last_turn None or last_turn['role'] == 'assistant'
|
690 |
+
conversations.append({
|
691 |
+
"role": "user",
|
692 |
+
"content": f"{image_token}"
|
693 |
+
})
|
694 |
+
image_paths.append(image_path)
|
695 |
+
else:
|
696 |
+
assert prompt is not None and isinstance(prompt, str)
|
697 |
+
if last_turn is not None and last_turn['role'] == 'user':
|
698 |
+
last_turn['content'] += f"\n{prompt}"
|
699 |
+
else:
|
700 |
+
conversations.append({
|
701 |
+
"role": "user",
|
702 |
+
"content": prompt,
|
703 |
+
})
|
704 |
+
if response is not None:
|
705 |
+
assert isinstance(response, str)
|
706 |
+
conversations.append({
|
707 |
+
"role": "assistant",
|
708 |
+
"content": response,
|
709 |
+
})
|
710 |
+
|
711 |
+
if conversations[0]['role'] != 'system':
|
712 |
+
system_prompt = system_prompt or SYSTEM_PROMPT
|
713 |
+
conversations = [{"role": "system", "content": system_prompt}] + conversations
|
714 |
+
return conversations, image_paths
|
715 |
+
|
716 |
+
|
717 |
+
|
718 |
+
def gradio_history_to_vision_conversation_prompt_paths(
|
719 |
+
history, system_prompt=None, image_token=None
|
720 |
+
):
|
721 |
+
"""
|
722 |
+
Aggregate gradio history into openai conversations
|
723 |
+
history = [
|
724 |
+
["Hello", "Response"],
|
725 |
+
[(file,), None],
|
726 |
+
]
|
727 |
+
--->
|
728 |
+
[
|
729 |
+
{"role": "user", "content": ...}
|
730 |
+
]
|
731 |
+
"""
|
732 |
+
global MODEL_ENGINE
|
733 |
+
|
734 |
+
conversations, image_paths = gradio_history_to_vision_conversations_paths(
|
735 |
+
history, system_prompt, image_token
|
736 |
+
)
|
737 |
+
# print(f'convo: {json.dumps(conversations, indent=4, ensure_ascii=False)}\n{image_paths=}')
|
738 |
+
full_prompt = MODEL_ENGINE.apply_chat_template(
|
739 |
+
conversations,
|
740 |
+
add_generation_prompt=True
|
741 |
+
)
|
742 |
+
return full_prompt, image_paths, conversations
|
743 |
+
|
744 |
+
|
745 |
+
def is_doc(file_path):
|
746 |
+
is_doc_allowed = file_path.endswith((".pdf", ".docx", ".txt"))
|
747 |
+
return is_doc_allowed
|
748 |
+
|
749 |
+
|
750 |
+
def read_doc(file_path):
|
751 |
+
from langchain_community.document_loaders import PyPDFLoader, Docx2txtLoader, TextLoader
|
752 |
+
if file_path.endswith('.pdf'):
|
753 |
+
loader = PyPDFLoader(file_path)
|
754 |
+
elif file_path.endswith('.docx'):
|
755 |
+
loader = Docx2txtLoader(file_path)
|
756 |
+
elif file_path.endswith('.txt'):
|
757 |
+
loader = TextLoader(file_path)
|
758 |
+
texts = loader.load()
|
759 |
+
text = "\n\n".join([t.page_content for t in texts])
|
760 |
+
return text
|
761 |
+
|
762 |
+
|
763 |
+
def doc_file_to_instruct_content(file_path, doc_instruction=None):
|
764 |
+
doc_instruction = doc_instruction or DOC_INSTRUCTION
|
765 |
+
content = doc_instruction.strip() + "\n" + DOC_TEMPLATE.format(content=read_doc(file_path))
|
766 |
+
return content
|
767 |
+
|
768 |
+
|
769 |
+
def gradio_history_to_doc_conversation_prompt(
|
770 |
+
history, system_prompt=None, doc_instruction=None,
|
771 |
+
):
|
772 |
+
"""
|
773 |
+
Aggregate gradio history into openai conversations
|
774 |
+
history = [
|
775 |
+
["Hello", "Response"],
|
776 |
+
[(file,), None],
|
777 |
+
]
|
778 |
+
--->
|
779 |
+
[
|
780 |
+
{"role": "user", "content": ...}
|
781 |
+
]
|
782 |
+
"""
|
783 |
+
global MODEL_ENGINE
|
784 |
+
# image_token = image_token or IMAGE_TOKEN
|
785 |
+
doc_instruction = doc_instruction or DOC_INSTRUCTION
|
786 |
+
conversations = []
|
787 |
+
image_paths = []
|
788 |
+
for i, his in enumerate(history):
|
789 |
+
prompt, response = his
|
790 |
+
last_turn = conversations[-1] if len(conversations) > 0 else None
|
791 |
+
if prompt is not None:
|
792 |
+
if isinstance(prompt, tuple):
|
793 |
+
file_path = prompt[0]
|
794 |
+
if not is_doc(file_path):
|
795 |
+
raise gr.Error(f'file not doc {file_path}')
|
796 |
+
content = doc_file_to_instruct_content(file_path, doc_instruction)
|
797 |
+
if last_turn is not None and last_turn['role'] == 'user':
|
798 |
+
last_turn['content'] += f"{content}"
|
799 |
+
else:
|
800 |
+
# last_turn None or last_turn['role'] == 'assistant'
|
801 |
+
conversations.append({
|
802 |
+
"role": "user",
|
803 |
+
"content": f"{content}"
|
804 |
+
})
|
805 |
+
else:
|
806 |
+
assert prompt is not None and isinstance(prompt, str)
|
807 |
+
if last_turn is not None and last_turn['role'] == 'user':
|
808 |
+
last_turn['content'] += f"\n{prompt}"
|
809 |
+
else:
|
810 |
+
conversations.append({
|
811 |
+
"role": "user",
|
812 |
+
"content": prompt,
|
813 |
+
})
|
814 |
+
if response is not None:
|
815 |
+
assert isinstance(response, str)
|
816 |
+
conversations.append({
|
817 |
+
"role": "assistant",
|
818 |
+
"content": response,
|
819 |
+
})
|
820 |
+
|
821 |
+
if conversations[0]['role'] != 'system':
|
822 |
+
system_prompt = system_prompt or SYSTEM_PROMPT
|
823 |
+
conversations = [{"role": "system", "content": system_prompt}] + conversations
|
824 |
+
|
825 |
+
full_prompt = MODEL_ENGINE.apply_chat_template(
|
826 |
+
conversations,
|
827 |
+
add_generation_prompt=True
|
828 |
+
)
|
829 |
+
return full_prompt, conversations
|
830 |
+
|
831 |
+
|
832 |
+
def gradio_history_to_vision_doc_conversation_prompt_paths(
|
833 |
+
history, system_prompt=None, image_token=None, doc_instruction=None,
|
834 |
+
):
|
835 |
+
"""
|
836 |
+
Aggregate gradio history into openai conversations
|
837 |
+
history = [
|
838 |
+
["Hello", "Response"],
|
839 |
+
[(file,), None],
|
840 |
+
]
|
841 |
+
--->
|
842 |
+
[
|
843 |
+
{"role": "user", "content": ...}
|
844 |
+
]
|
845 |
+
"""
|
846 |
+
global MODEL_ENGINE
|
847 |
+
image_token = image_token or IMAGE_TOKEN
|
848 |
+
doc_instruction = doc_instruction or DOC_INSTRUCTION
|
849 |
+
conversations = []
|
850 |
+
image_paths = []
|
851 |
+
for i, his in enumerate(history):
|
852 |
+
prompt, response = his
|
853 |
+
last_turn = conversations[-1] if len(conversations) > 0 else None
|
854 |
+
if prompt is not None:
|
855 |
+
if isinstance(prompt, tuple):
|
856 |
+
file_path = prompt[0]
|
857 |
+
if is_doc(file_path):
|
858 |
+
content = doc_file_to_instruct_content(file_path, doc_instruction)
|
859 |
+
if last_turn is not None and last_turn['role'] == 'user':
|
860 |
+
last_turn['content'] += f"{content}"
|
861 |
+
else:
|
862 |
+
# last_turn None or last_turn['role'] == 'assistant'
|
863 |
+
conversations.append({
|
864 |
+
"role": "user",
|
865 |
+
"content": f"{content}"
|
866 |
+
})
|
867 |
+
else:
|
868 |
+
if last_turn is not None and last_turn['role'] == 'user':
|
869 |
+
last_turn['content'] += f" {image_token}"
|
870 |
+
else:
|
871 |
+
# last_turn None or last_turn['role'] == 'assistant'
|
872 |
+
conversations.append({
|
873 |
+
"role": "user",
|
874 |
+
"content": f"{image_token}"
|
875 |
+
})
|
876 |
+
image_paths.append(file_path)
|
877 |
+
else:
|
878 |
+
assert prompt is not None and isinstance(prompt, str)
|
879 |
+
if last_turn is not None and last_turn['role'] == 'user':
|
880 |
+
last_turn['content'] += f"\n{prompt}"
|
881 |
+
else:
|
882 |
+
conversations.append({
|
883 |
+
"role": "user",
|
884 |
+
"content": prompt,
|
885 |
+
})
|
886 |
+
if response is not None:
|
887 |
+
assert isinstance(response, str)
|
888 |
+
conversations.append({
|
889 |
+
"role": "assistant",
|
890 |
+
"content": response,
|
891 |
+
})
|
892 |
+
|
893 |
+
if conversations[0]['role'] != 'system':
|
894 |
+
system_prompt = system_prompt or SYSTEM_PROMPT
|
895 |
+
conversations = [{"role": "system", "content": system_prompt}] + conversations
|
896 |
+
|
897 |
+
full_prompt = MODEL_ENGINE.apply_chat_template(
|
898 |
+
conversations,
|
899 |
+
add_generation_prompt=True
|
900 |
+
)
|
901 |
+
return full_prompt, image_paths, conversations
|
902 |
+
|
903 |
+
|
904 |
+
def vision_chat_response_stream_multiturn_engine(
|
905 |
+
history: List[Tuple[str, str]],
|
906 |
+
temperature: float,
|
907 |
+
max_tokens: int,
|
908 |
+
system_prompt: Optional[str] = SYSTEM_PROMPT,
|
909 |
+
image_token: Optional[str] = IMAGE_TOKEN,
|
910 |
+
):
|
911 |
+
global MODEL_ENGINE
|
912 |
+
temperature = float(temperature)
|
913 |
+
# ! remove frequency_penalty
|
914 |
+
# frequency_penalty = float(frequency_penalty)
|
915 |
+
max_tokens = int(max_tokens)
|
916 |
+
# ! skip safety
|
917 |
+
if DATETIME_FORMAT in system_prompt:
|
918 |
+
# ! This sometime works sometimes dont
|
919 |
+
system_prompt = system_prompt.format(cur_datetime=get_datetime_string())
|
920 |
+
# ! history now can have multimodal
|
921 |
+
|
922 |
+
full_prompt, image_paths, conversations = gradio_history_to_vision_conversation_prompt_paths(
|
923 |
+
history=history, system_prompt=system_prompt, image_token=image_token
|
924 |
+
)
|
925 |
+
|
926 |
+
if hasattr(MODEL_ENGINE, "get_multimodal_tokens"):
|
927 |
+
num_tokens = MODEL_ENGINE.get_multimodal_tokens(full_prompt, image_paths=image_paths)
|
928 |
+
else:
|
929 |
+
num_tokens = len(MODEL_ENGINE.tokenizer.encode(full_prompt))
|
930 |
+
if num_tokens >= MODEL_ENGINE.max_position_embeddings - 128:
|
931 |
+
raise gr.Error(f"Conversation or prompt is too long ({num_tokens} toks), please clear the chatbox or try shorter input.")
|
932 |
+
|
933 |
+
print(f'{image_paths=}')
|
934 |
+
print(full_prompt)
|
935 |
+
outputs = None
|
936 |
+
response = None
|
937 |
+
num_tokens = -1
|
938 |
+
for j, outputs in enumerate(MODEL_ENGINE.generate_yield_string(
|
939 |
+
prompt=full_prompt,
|
940 |
+
temperature=temperature,
|
941 |
+
max_tokens=max_tokens,
|
942 |
+
image_paths=image_paths,
|
943 |
+
)):
|
944 |
+
if isinstance(outputs, tuple):
|
945 |
+
response, num_tokens = outputs
|
946 |
+
else:
|
947 |
+
response, num_tokens = outputs, -1
|
948 |
+
yield response, num_tokens
|
949 |
+
|
950 |
+
if response is not None:
|
951 |
+
yield response, num_tokens
|
952 |
+
|
953 |
+
|
954 |
+
def doc_chat_response_stream_multiturn_engine(
|
955 |
+
history: List[Tuple[str, str]],
|
956 |
+
temperature: float,
|
957 |
+
max_tokens: int,
|
958 |
+
system_prompt: Optional[str] = SYSTEM_PROMPT,
|
959 |
+
doc_instruction: Optional[str] = DOC_INSTRUCTION,
|
960 |
+
):
|
961 |
+
global MODEL_ENGINE
|
962 |
+
temperature = float(temperature)
|
963 |
+
# ! remove frequency_penalty
|
964 |
+
# frequency_penalty = float(frequency_penalty)
|
965 |
+
max_tokens = int(max_tokens)
|
966 |
+
# ! skip safety
|
967 |
+
if DATETIME_FORMAT in system_prompt:
|
968 |
+
# ! This sometime works sometimes dont
|
969 |
+
system_prompt = system_prompt.format(cur_datetime=get_datetime_string())
|
970 |
+
# ! history now can have multimodal
|
971 |
+
|
972 |
+
full_prompt, conversations = gradio_history_to_doc_conversation_prompt(
|
973 |
+
history=history, system_prompt=system_prompt, doc_instruction=doc_instruction
|
974 |
+
)
|
975 |
+
|
976 |
+
# ! length checked
|
977 |
+
num_tokens = len(MODEL_ENGINE.tokenizer.encode(full_prompt))
|
978 |
+
if num_tokens >= MODEL_ENGINE.max_position_embeddings - 128:
|
979 |
+
raise gr.Error(f"Conversation or prompt is too long ({num_tokens} toks), please clear the chatbox or try shorter input.")
|
980 |
+
|
981 |
+
print(full_prompt)
|
982 |
+
outputs = None
|
983 |
+
response = None
|
984 |
+
num_tokens = -1
|
985 |
+
for j, outputs in enumerate(MODEL_ENGINE.generate_yield_string(
|
986 |
+
prompt=full_prompt,
|
987 |
+
temperature=temperature,
|
988 |
+
max_tokens=max_tokens,
|
989 |
+
# image_paths=image_paths,
|
990 |
+
)):
|
991 |
+
if isinstance(outputs, tuple):
|
992 |
+
response, num_tokens = outputs
|
993 |
+
else:
|
994 |
+
response, num_tokens = outputs, -1
|
995 |
+
yield response, num_tokens
|
996 |
+
|
997 |
+
if response is not None:
|
998 |
+
yield response, num_tokens
|
999 |
+
|
1000 |
+
|
1001 |
+
|
1002 |
+
|
1003 |
+
def vision_doc_chat_response_stream_multiturn_engine(
|
1004 |
+
history: List[Tuple[str, str]],
|
1005 |
+
temperature: float,
|
1006 |
+
max_tokens: int,
|
1007 |
+
system_prompt: Optional[str] = SYSTEM_PROMPT,
|
1008 |
+
image_token: Optional[str] = IMAGE_TOKEN,
|
1009 |
+
doc_instruction: Optional[str] = DOC_INSTRUCTION,
|
1010 |
+
):
|
1011 |
+
global MODEL_ENGINE
|
1012 |
+
temperature = float(temperature)
|
1013 |
+
# ! remove frequency_penalty
|
1014 |
+
# frequency_penalty = float(frequency_penalty)
|
1015 |
+
max_tokens = int(max_tokens)
|
1016 |
+
# ! skip safety
|
1017 |
+
if DATETIME_FORMAT in system_prompt:
|
1018 |
+
# ! This sometime works sometimes dont
|
1019 |
+
system_prompt = system_prompt.format(cur_datetime=get_datetime_string())
|
1020 |
+
# ! history now can have multimodal
|
1021 |
+
|
1022 |
+
full_prompt, image_paths, conversations = gradio_history_to_vision_doc_conversation_prompt_paths(
|
1023 |
+
history=history, system_prompt=system_prompt, image_token=image_token, doc_instruction=doc_instruction
|
1024 |
+
)
|
1025 |
+
|
1026 |
+
# ! length check
|
1027 |
+
if hasattr(MODEL_ENGINE, "get_multimodal_tokens"):
|
1028 |
+
num_tokens = MODEL_ENGINE.get_multimodal_tokens(full_prompt, image_paths=image_paths)
|
1029 |
+
else:
|
1030 |
+
num_tokens = len(MODEL_ENGINE.tokenizer.encode(full_prompt))
|
1031 |
+
if num_tokens >= MODEL_ENGINE.max_position_embeddings - 128:
|
1032 |
+
raise gr.Error(f"Conversation or prompt is too long ({num_tokens} toks), please clear the chatbox or try shorter input.")
|
1033 |
+
|
1034 |
+
print(full_prompt)
|
1035 |
+
print(f'{image_paths=}')
|
1036 |
+
outputs = None
|
1037 |
+
response = None
|
1038 |
+
num_tokens = -1
|
1039 |
+
for j, outputs in enumerate(MODEL_ENGINE.generate_yield_string(
|
1040 |
+
prompt=full_prompt,
|
1041 |
+
temperature=temperature,
|
1042 |
+
max_tokens=max_tokens,
|
1043 |
+
image_paths=image_paths,
|
1044 |
+
)):
|
1045 |
+
if isinstance(outputs, tuple):
|
1046 |
+
response, num_tokens = outputs
|
1047 |
+
else:
|
1048 |
+
response, num_tokens = outputs, -1
|
1049 |
+
yield response, num_tokens
|
1050 |
+
|
1051 |
+
if response is not None:
|
1052 |
+
yield response, num_tokens
|
1053 |
+
|
1054 |
+
|
1055 |
+
|
1056 |
+
@register_demo
|
1057 |
+
class VisionChatInterfaceDemo(ChatInterfaceDemo):
|
1058 |
+
"""
|
1059 |
+
Accept vision image
|
1060 |
+
"""
|
1061 |
+
|
1062 |
+
@property
|
1063 |
+
def tab_name(self):
|
1064 |
+
return "Vision Chat"
|
1065 |
+
|
1066 |
+
@property
|
1067 |
+
def examples(self):
|
1068 |
+
return [
|
1069 |
+
["What's strange about this image?", "assets/dog_monalisa.jpeg",],
|
1070 |
+
["Explain why the sky is blue.", None,],
|
1071 |
+
]
|
1072 |
+
|
1073 |
+
def create_demo(
|
1074 |
+
self,
|
1075 |
+
title: str | None = None,
|
1076 |
+
description: str | None = None,
|
1077 |
+
**kwargs
|
1078 |
+
) -> gr.Blocks:
|
1079 |
+
system_prompt = kwargs.get("system_prompt", SYSTEM_PROMPT)
|
1080 |
+
max_tokens = kwargs.get("max_tokens", MAX_TOKENS)
|
1081 |
+
temperature = kwargs.get("temperature", TEMPERATURE)
|
1082 |
+
model_name = kwargs.get("model_name", MODEL_NAME)
|
1083 |
+
description = description or """Upload an image to ask question about it."""
|
1084 |
+
|
1085 |
+
def add_multimodal_fn() -> List[Component]:
|
1086 |
+
image_input = gr.Image(label="Input Image", type="filepath", )
|
1087 |
+
return [image_input]
|
1088 |
+
|
1089 |
+
additional_inputs = [
|
1090 |
+
gr.Number(value=temperature, label='Temperature', min_width=20),
|
1091 |
+
gr.Number(value=max_tokens, label='Max-tokens', min_width=20),
|
1092 |
+
gr.Textbox(value=system_prompt, label='System prompt', lines=1),
|
1093 |
+
gr.Textbox(value=IMAGE_TOKEN, label='Visual token', lines=1, interactive=IMAGE_TOKEN_INTERACTIVE, min_width=20),
|
1094 |
+
]
|
1095 |
+
def render_additional_inputs_fn():
|
1096 |
+
with Row():
|
1097 |
+
additional_inputs[0].render()
|
1098 |
+
additional_inputs[1].render()
|
1099 |
+
additional_inputs[3].render()
|
1100 |
+
additional_inputs[2].render()
|
1101 |
+
|
1102 |
+
demo_chat = MultiModalChatInterface(
|
1103 |
+
vision_chat_response_stream_multiturn_engine,
|
1104 |
+
chatbot=gr.Chatbot(
|
1105 |
+
label=model_name,
|
1106 |
+
bubble_full_width=False,
|
1107 |
+
latex_delimiters=[
|
1108 |
+
{ "left": "$", "right": "$", "display": False},
|
1109 |
+
{ "left": "$$", "right": "$$", "display": True},
|
1110 |
+
],
|
1111 |
+
show_copy_button=True,
|
1112 |
+
layout="panel" if USE_PANEL else "bubble",
|
1113 |
+
height=CHATBOT_HEIGHT,
|
1114 |
+
),
|
1115 |
+
# textbox=gr.Textbox(placeholder='Type message', lines=4, max_lines=128, min_width=200),
|
1116 |
+
textbox=gr.Textbox(placeholder='Type message', lines=1, max_lines=128, min_width=200, scale=8),
|
1117 |
+
submit_btn=gr.Button(value='Submit', variant="primary", scale=0),
|
1118 |
+
# ! consider preventing the stop button
|
1119 |
+
# stop_btn=None,
|
1120 |
+
add_multimodal_fn=add_multimodal_fn,
|
1121 |
+
title=title,
|
1122 |
+
description=description,
|
1123 |
+
additional_inputs=additional_inputs,
|
1124 |
+
render_additional_inputs_fn=render_additional_inputs_fn,
|
1125 |
+
additional_inputs_accordion=gr.Accordion("Additional Inputs", open=True),
|
1126 |
+
examples=self.examples,
|
1127 |
+
cache_examples=False,
|
1128 |
+
css=CSS,
|
1129 |
+
)
|
1130 |
+
return demo_chat
|
1131 |
+
|
1132 |
+
|
1133 |
+
def add_document_upload():
|
1134 |
+
file_input = gr.File(label='Upload pdf, docx, txt', file_count='single', file_types=['pdf', 'docx', 'txt'])
|
1135 |
+
# with Group():
|
1136 |
+
# file_input = gr.Textbox(value=None, label='Document path', lines=1, interactive=False)
|
1137 |
+
# upload_button = gr.UploadButton("Click to Upload document", file_types=['pdf', 'docx', 'txt'], file_count="single")
|
1138 |
+
# upload_button.upload(lambda x: x.name, upload_button, file_input)
|
1139 |
+
return file_input
|
1140 |
+
|
1141 |
+
|
1142 |
+
@register_demo
|
1143 |
+
class DocChatInterfaceDemo(ChatInterfaceDemo):
|
1144 |
+
"""
|
1145 |
+
Accept document (full length no RAG)
|
1146 |
+
"""
|
1147 |
+
@property
|
1148 |
+
def tab_name(self):
|
1149 |
+
return "Doc Chat"
|
1150 |
+
|
1151 |
+
@property
|
1152 |
+
def examples(self):
|
1153 |
+
return [
|
1154 |
+
["Summarize the document", "assets/attention_short.pdf",],
|
1155 |
+
["Explain why the sky is blue.", None,],
|
1156 |
+
]
|
1157 |
+
|
1158 |
+
def create_demo(
|
1159 |
+
self,
|
1160 |
+
title: str | None = None,
|
1161 |
+
description: str | None = None,
|
1162 |
+
**kwargs
|
1163 |
+
) -> gr.Blocks:
|
1164 |
+
system_prompt = kwargs.get("system_prompt", SYSTEM_PROMPT)
|
1165 |
+
max_tokens = kwargs.get("max_tokens", MAX_TOKENS)
|
1166 |
+
temperature = kwargs.get("temperature", TEMPERATURE)
|
1167 |
+
model_name = kwargs.get("model_name", MODEL_NAME)
|
1168 |
+
# frequence_penalty = FREQUENCE_PENALTY
|
1169 |
+
# presence_penalty = PRESENCE_PENALTY
|
1170 |
+
description = description or """Upload a short document to ask question about it."""
|
1171 |
+
|
1172 |
+
def add_multimodal_fn() -> List[Component]:
|
1173 |
+
file_input = add_document_upload()
|
1174 |
+
# image_input = gr.Image(label="Input Image", type="filepath", )
|
1175 |
+
return [file_input]
|
1176 |
+
|
1177 |
+
additional_inputs = [
|
1178 |
+
gr.Number(value=temperature, label='Temperature', min_width=20),
|
1179 |
+
gr.Number(value=max_tokens, label='Max-tokens', min_width=20),
|
1180 |
+
gr.Textbox(value=system_prompt, label='System prompt', lines=1),
|
1181 |
+
gr.Textbox(value=DOC_INSTRUCTION, label='Doc instruction', lines=1),
|
1182 |
+
]
|
1183 |
+
def render_additional_inputs_fn():
|
1184 |
+
with Row():
|
1185 |
+
additional_inputs[0].render()
|
1186 |
+
additional_inputs[1].render()
|
1187 |
+
additional_inputs[2].render()
|
1188 |
+
additional_inputs[3].render()
|
1189 |
+
|
1190 |
+
demo_chat = MultiModalChatInterface(
|
1191 |
+
doc_chat_response_stream_multiturn_engine,
|
1192 |
+
chatbot=gr.Chatbot(
|
1193 |
+
label=model_name,
|
1194 |
+
bubble_full_width=False,
|
1195 |
+
latex_delimiters=[
|
1196 |
+
{ "left": "$", "right": "$", "display": False},
|
1197 |
+
{ "left": "$$", "right": "$$", "display": True},
|
1198 |
+
],
|
1199 |
+
show_copy_button=True,
|
1200 |
+
layout="panel" if USE_PANEL else "bubble",
|
1201 |
+
height=CHATBOT_HEIGHT,
|
1202 |
+
),
|
1203 |
+
textbox=gr.Textbox(placeholder='Type message', lines=1, max_lines=128, min_width=200, scale=8),
|
1204 |
+
submit_btn=gr.Button(value='Submit', variant="primary", scale=0),
|
1205 |
+
# ! consider preventing the stop button
|
1206 |
+
add_multimodal_fn=add_multimodal_fn,
|
1207 |
+
title=title,
|
1208 |
+
description=description,
|
1209 |
+
additional_inputs=additional_inputs,
|
1210 |
+
render_additional_inputs_fn=render_additional_inputs_fn,
|
1211 |
+
additional_inputs_accordion=gr.Accordion("Additional Inputs", open=True),
|
1212 |
+
examples=self.examples,
|
1213 |
+
cache_examples=False,
|
1214 |
+
css=CSS,
|
1215 |
+
)
|
1216 |
+
return demo_chat
|
1217 |
+
|
1218 |
+
|
1219 |
+
@register_demo
|
1220 |
+
class VisionDocChatInterfaceDemo(ChatInterfaceDemo):
|
1221 |
+
"""
|
1222 |
+
Accept either vision image or document (full length no RAG)
|
1223 |
+
"""
|
1224 |
+
@property
|
1225 |
+
def tab_name(self):
|
1226 |
+
return "Vision Doc Chat"
|
1227 |
+
|
1228 |
+
@property
|
1229 |
+
def examples(self):
|
1230 |
+
return [
|
1231 |
+
["What's strange about this image?", None, "assets/dog_monalisa.jpeg",],
|
1232 |
+
["Summarize the document", "assets/attention_short.pdf", None,],
|
1233 |
+
["Explain why the sky is blue.", None, None],
|
1234 |
+
]
|
1235 |
+
|
1236 |
+
def create_demo(
|
1237 |
+
self,
|
1238 |
+
title: str | None = None,
|
1239 |
+
description: str | None = None,
|
1240 |
+
**kwargs
|
1241 |
+
) -> gr.Blocks:
|
1242 |
+
system_prompt = kwargs.get("system_prompt", SYSTEM_PROMPT)
|
1243 |
+
max_tokens = kwargs.get("max_tokens", MAX_TOKENS)
|
1244 |
+
temperature = kwargs.get("temperature", TEMPERATURE)
|
1245 |
+
model_name = kwargs.get("model_name", MODEL_NAME)
|
1246 |
+
# frequence_penalty = FREQUENCE_PENALTY
|
1247 |
+
# presence_penalty = PRESENCE_PENALTY
|
1248 |
+
description = description or """Upload either an image or short document to ask question about it."""
|
1249 |
+
|
1250 |
+
def add_multimodal_fn() -> List[Component]:
|
1251 |
+
file_input = add_document_upload()
|
1252 |
+
image_input = gr.Image(label="Input Image", type="filepath", )
|
1253 |
+
return [file_input, image_input]
|
1254 |
+
|
1255 |
+
additional_inputs = [
|
1256 |
+
gr.Number(value=temperature, label='Temperature', min_width=20),
|
1257 |
+
gr.Number(value=max_tokens, label='Max-tokens', min_width=20),
|
1258 |
+
gr.Textbox(value=system_prompt, label='System prompt', lines=1),
|
1259 |
+
gr.Textbox(value=IMAGE_TOKEN, label='Visual token', lines=1, interactive=IMAGE_TOKEN_INTERACTIVE, min_width=2),
|
1260 |
+
gr.Textbox(value=DOC_INSTRUCTION, label='Doc instruction', lines=1),
|
1261 |
+
]
|
1262 |
+
def render_additional_inputs_fn():
|
1263 |
+
with Row():
|
1264 |
+
additional_inputs[0].render()
|
1265 |
+
additional_inputs[1].render()
|
1266 |
+
additional_inputs[3].render()
|
1267 |
+
additional_inputs[2].render()
|
1268 |
+
additional_inputs[4].render()
|
1269 |
+
|
1270 |
+
demo_chat = MultiModalChatInterface(
|
1271 |
+
vision_doc_chat_response_stream_multiturn_engine,
|
1272 |
+
chatbot=gr.Chatbot(
|
1273 |
+
label=MODEL_NAME,
|
1274 |
+
bubble_full_width=False,
|
1275 |
+
latex_delimiters=[
|
1276 |
+
{ "left": "$", "right": "$", "display": False},
|
1277 |
+
{ "left": "$$", "right": "$$", "display": True},
|
1278 |
+
],
|
1279 |
+
show_copy_button=True,
|
1280 |
+
layout="panel" if USE_PANEL else "bubble",
|
1281 |
+
height=CHATBOT_HEIGHT,
|
1282 |
+
),
|
1283 |
+
textbox=gr.Textbox(placeholder='Type message', lines=1, max_lines=128, min_width=200, scale=8),
|
1284 |
+
submit_btn=gr.Button(value='Submit', variant="primary", scale=0),
|
1285 |
+
add_multimodal_fn=add_multimodal_fn,
|
1286 |
+
title=title,
|
1287 |
+
description=description,
|
1288 |
+
additional_inputs=additional_inputs,
|
1289 |
+
render_additional_inputs_fn=render_additional_inputs_fn,
|
1290 |
+
additional_inputs_accordion=gr.Accordion("Additional Inputs", open=True),
|
1291 |
+
examples=self.examples,
|
1292 |
+
cache_examples=False,
|
1293 |
+
css=CSS,
|
1294 |
+
)
|
1295 |
+
return demo_chat
|
multipurpose_chatbot/demos/multimodal_preference_interface.py
ADDED
@@ -0,0 +1,794 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
from gradio.themes import ThemeClass as Theme
|
3 |
+
import numpy as np
|
4 |
+
import argparse
|
5 |
+
import gradio as gr
|
6 |
+
from typing import Any, Iterator
|
7 |
+
from typing import Iterator, List, Optional, Tuple
|
8 |
+
import filelock
|
9 |
+
import glob
|
10 |
+
import json
|
11 |
+
import time
|
12 |
+
from gradio.routes import Request
|
13 |
+
from gradio.utils import SyncToAsyncIterator, async_iteration
|
14 |
+
from gradio.helpers import special_args
|
15 |
+
import anyio
|
16 |
+
from typing import AsyncGenerator, Callable, Literal, Union, cast, Generator
|
17 |
+
|
18 |
+
from gradio_client.documentation import document, set_documentation_group
|
19 |
+
from gradio.components import Button, Component
|
20 |
+
from gradio.events import Dependency, EventListenerMethod
|
21 |
+
from typing import List, Optional, Union, Dict, Tuple
|
22 |
+
from tqdm.auto import tqdm
|
23 |
+
from huggingface_hub import snapshot_download
|
24 |
+
from gradio.components.base import Component
|
25 |
+
|
26 |
+
from .base_demo import register_demo, get_demo_class, BaseDemo
|
27 |
+
|
28 |
+
|
29 |
+
from .chat_interface import (
|
30 |
+
SYSTEM_PROMPT,
|
31 |
+
MODEL_NAME,
|
32 |
+
MAX_TOKENS,
|
33 |
+
TEMPERATURE,
|
34 |
+
CHAT_EXAMPLES,
|
35 |
+
gradio_history_to_openai_conversations,
|
36 |
+
gradio_history_to_conversation_prompt,
|
37 |
+
DATETIME_FORMAT,
|
38 |
+
get_datetime_string,
|
39 |
+
chat_response_stream_multiturn_engine,
|
40 |
+
ChatInterfaceDemo,
|
41 |
+
CustomizedChatInterface,
|
42 |
+
)
|
43 |
+
|
44 |
+
from gradio.events import Events
|
45 |
+
|
46 |
+
import inspect
|
47 |
+
from typing import AsyncGenerator, Callable, Literal, Union, cast
|
48 |
+
|
49 |
+
import anyio
|
50 |
+
from gradio_client import utils as client_utils
|
51 |
+
from gradio_client.documentation import document
|
52 |
+
|
53 |
+
from gradio.blocks import Blocks
|
54 |
+
from gradio.components import (
|
55 |
+
Button,
|
56 |
+
Chatbot,
|
57 |
+
Component,
|
58 |
+
Markdown,
|
59 |
+
State,
|
60 |
+
Textbox,
|
61 |
+
get_component_instance,
|
62 |
+
)
|
63 |
+
from gradio.events import Dependency, on
|
64 |
+
from gradio.helpers import create_examples as Examples # noqa: N812
|
65 |
+
from gradio.helpers import special_args
|
66 |
+
from gradio.layouts import Accordion, Group, Row
|
67 |
+
from gradio.routes import Request
|
68 |
+
from gradio.themes import ThemeClass as Theme
|
69 |
+
from gradio.utils import SyncToAsyncIterator, async_iteration
|
70 |
+
|
71 |
+
from ..globals import MODEL_ENGINE
|
72 |
+
|
73 |
+
from ..configs import (
|
74 |
+
USE_PANEL,
|
75 |
+
IMAGE_TOKEN,
|
76 |
+
IMAGE_TOKEN_INTERACTIVE,
|
77 |
+
CHATBOT_HEIGHT,
|
78 |
+
ALLOWED_PATHS,
|
79 |
+
)
|
80 |
+
|
81 |
+
|
82 |
+
from .multimodal_chat_interface import (
|
83 |
+
DOC_INSTRUCTION,
|
84 |
+
DOC_TEMPLATE,
|
85 |
+
CSS,
|
86 |
+
undo_history,
|
87 |
+
undo_history_until_last_assistant_turn,
|
88 |
+
MultiModalChatInterface,
|
89 |
+
gradio_history_to_conversation_prompt,
|
90 |
+
gradio_history_to_openai_conversations,
|
91 |
+
gradio_history_to_vision_conversation_prompt_paths,
|
92 |
+
gradio_history_to_doc_conversation_prompt,
|
93 |
+
gradio_history_to_vision_doc_conversation_prompt_paths,
|
94 |
+
VisionChatInterfaceDemo,
|
95 |
+
vision_chat_response_stream_multiturn_engine,
|
96 |
+
)
|
97 |
+
|
98 |
+
import glob
|
99 |
+
from pathlib import Path
|
100 |
+
from gradio import utils as gradio_utils
|
101 |
+
|
102 |
+
PREF_DIR = os.environ.get("PREF_DIR", "./tmp")
|
103 |
+
PREFERENCE_MAKE_DATA_PATH = os.environ.get("PREFERENCE_MAKE_DATA_PATH", "assets/example_pref.json")
|
104 |
+
|
105 |
+
IMAGE_DIR = os.environ.get("IMAGE_DIR", "./tmp_image")
|
106 |
+
|
107 |
+
EXAMPLE_IMAGE_PATHS = [
|
108 |
+
x
|
109 |
+
for x in glob.glob(os.path.join(IMAGE_DIR, "*"))
|
110 |
+
]
|
111 |
+
print(f'IMAGES: {EXAMPLE_IMAGE_PATHS[:3]=}')
|
112 |
+
|
113 |
+
|
114 |
+
# ! Existing images
|
115 |
+
|
116 |
+
IMAGE_GLOB_ROOT = "/mnt/workspace/workgroup/phi/raw_data/multimodal_seallm/processed/sft/dpo_examples"
|
117 |
+
# ALLOWED_PATHS.append(IMAGE_GLOB_ROOT)
|
118 |
+
IMAGE_GLOBS = {
|
119 |
+
# "geometry": "geo3k/train/*/img_diagram.png",
|
120 |
+
"Geometry": ["geoqa_plus/*png", "Ask question about to solve the puzzle, calculating angles, find values, ... Provide extra information in the question (e.g 'Angle 1 = 30 degrees, find angle 2 from image.')"],
|
121 |
+
"Everyday": ["gqa/images/*", "Ask question to (1) describe, (2) find details, (3) negation (e.g 'Where's the cat?' while there is no cat in image.), (4) write stories ...."],
|
122 |
+
"OCR (read text)": ["ocr_vqa/images/*", "Ask question (1) full OCR description, (2) read specific details (e.g 'Who wrote the book?')."],
|
123 |
+
"OpenViVQA": ["OpenViVQA/training-images/*", "Only vietnamese, (1) full OCR description, (2) read specific details, (3) image description and question answering"],
|
124 |
+
"Text-VQA": ["textvqa/train_images/*", "Ask question to (1) describe, (2) find details, (3) negation (e.g 'Where's the cat?' while there is no cat in image.), (4) write stories, (5) reasoning"],
|
125 |
+
"Landmarks": ["web-landmark/images/*", "Ask question to (1) Where is landmarks (2) What to do at that place (3) Write stories, (4) give advise for tourists..."],
|
126 |
+
"Everyday-VG2": ["vg/VG_100K_2/*", "Same with Everyday"],
|
127 |
+
}
|
128 |
+
|
129 |
+
IMAGE_CUT_OFF_BEGIN = 0
|
130 |
+
IMAGE_CUT_OFF = 100
|
131 |
+
# IMAGE_CUT_OFF = 20
|
132 |
+
|
133 |
+
IMAGE_GLOB_PATHS = {}
|
134 |
+
IMAGE_GLOB_DESCS = {}
|
135 |
+
for k, v in IMAGE_GLOBS.items():
|
136 |
+
glob_p, description = v
|
137 |
+
paths = []
|
138 |
+
for i, p in enumerate(glob.glob(os.path.join(IMAGE_GLOB_ROOT, glob_p))):
|
139 |
+
if i < IMAGE_CUT_OFF_BEGIN:
|
140 |
+
continue
|
141 |
+
if i >= IMAGE_CUT_OFF + IMAGE_CUT_OFF_BEGIN:
|
142 |
+
break
|
143 |
+
paths.append(p)
|
144 |
+
IMAGE_GLOB_PATHS[k] = paths
|
145 |
+
IMAGE_GLOB_DESCS[k] = description
|
146 |
+
|
147 |
+
print(IMAGE_GLOB_PATHS['Geometry'][:10])
|
148 |
+
|
149 |
+
|
150 |
+
def read_json(json_file):
|
151 |
+
print(f'Reading : {json_file}')
|
152 |
+
with open(json_file, 'r', encoding='utf-8') as f:
|
153 |
+
rows = json.load(f)
|
154 |
+
return rows
|
155 |
+
|
156 |
+
|
157 |
+
def write_json(data, json_file):
|
158 |
+
with open(json_file, 'w', encoding='utf-8') as f:
|
159 |
+
json.dump(data, f, indent=4, ensure_ascii=False)
|
160 |
+
|
161 |
+
|
162 |
+
def convert_pref_data_to_openai_format(rows_dict):
|
163 |
+
for key, r in rows_dict.items():
|
164 |
+
if "conversation_prefix" in r:
|
165 |
+
assert "responses" in r, f'invalid: {r}'
|
166 |
+
continue
|
167 |
+
history = r['history']
|
168 |
+
conversations = []
|
169 |
+
for user, assistant in history:
|
170 |
+
conversations.append({"role": "user", "content": user.strip()})
|
171 |
+
conversations.append({"role": "assistant", "content": assistant.strip()})
|
172 |
+
r['conversation_prefix'] = conversations[:-1]
|
173 |
+
r['responses'] = [conversations[-1]]
|
174 |
+
r['original_response'] = conversations[-1]
|
175 |
+
if "lang" not in r:
|
176 |
+
r['lang'] = key[-2:]
|
177 |
+
# missing an item in responses
|
178 |
+
lang_set = list(set([r['lang'] for r in rows_dict.values()]))
|
179 |
+
return rows_dict, lang_set
|
180 |
+
|
181 |
+
|
182 |
+
def convert_mm_pref_data_to_openai_format(rows_dict):
|
183 |
+
pass
|
184 |
+
|
185 |
+
|
186 |
+
PREFERENCE_RATE_DICT = None
|
187 |
+
LANG_SET = ["en", "vi", "id", 'ms', "th", "zh", 'lo', 'km', 'tl', 'my']
|
188 |
+
if PREFERENCE_MAKE_DATA_PATH is not None and os.path.exists(PREFERENCE_MAKE_DATA_PATH):
|
189 |
+
print(f'Loading {PREFERENCE_MAKE_DATA_PATH}')
|
190 |
+
PREFERENCE_RATE_DICT = read_json(PREFERENCE_MAKE_DATA_PATH)
|
191 |
+
PREFERENCE_RATE_DICT, _LANG_SET = convert_pref_data_to_openai_format(PREFERENCE_RATE_DICT)
|
192 |
+
LANG_SET = LANG_SET + [l for l in _LANG_SET if l not in LANG_SET]
|
193 |
+
|
194 |
+
|
195 |
+
|
196 |
+
|
197 |
+
|
198 |
+
@document()
|
199 |
+
class CustomJsonlLogger(gr.FlaggingCallback):
|
200 |
+
def __init__(self):
|
201 |
+
self.num_lines = 0
|
202 |
+
|
203 |
+
def setup(
|
204 |
+
self,
|
205 |
+
components: list[Component],
|
206 |
+
flagging_dir: Union[str, Path],
|
207 |
+
):
|
208 |
+
self.components = components
|
209 |
+
self.flagging_dir = flagging_dir
|
210 |
+
os.makedirs(flagging_dir, exist_ok=True)
|
211 |
+
flagging_dir = self.flagging_dir
|
212 |
+
log_filepath = Path(flagging_dir) / "log.jsonl"
|
213 |
+
if Path(log_filepath).exists():
|
214 |
+
with open(log_filepath, "rb") as f:
|
215 |
+
self.num_lines = sum(1 for _ in f)
|
216 |
+
else:
|
217 |
+
self.num_lines = 0
|
218 |
+
|
219 |
+
def flag(
|
220 |
+
self,
|
221 |
+
flag_data: list[Any],
|
222 |
+
flag_option: str = "",
|
223 |
+
username: Union[str, None] = None,
|
224 |
+
) -> int:
|
225 |
+
import datetime
|
226 |
+
flagging_dir = self.flagging_dir
|
227 |
+
log_filepath = Path(flagging_dir) / "log.jsonl"
|
228 |
+
is_new = not Path(log_filepath).exists()
|
229 |
+
headers = [
|
230 |
+
getattr(component, "label", None) or f"component {idx}"
|
231 |
+
for idx, component in enumerate(self.components)
|
232 |
+
] + [
|
233 |
+
"flag",
|
234 |
+
"username",
|
235 |
+
"timestamp",
|
236 |
+
]
|
237 |
+
|
238 |
+
csv_data = []
|
239 |
+
for idx, (component, sample) in enumerate(zip(self.components, flag_data)):
|
240 |
+
save_dir = Path(
|
241 |
+
flagging_dir
|
242 |
+
) / client_utils.strip_invalid_filename_characters(
|
243 |
+
getattr(component, "label", None) or f"component {idx}"
|
244 |
+
)
|
245 |
+
if gradio_utils.is_update(sample):
|
246 |
+
csv_data.append(str(sample))
|
247 |
+
else:
|
248 |
+
csv_data.append(
|
249 |
+
component.flag(sample, flag_dir=save_dir)
|
250 |
+
if sample is not None
|
251 |
+
else ""
|
252 |
+
)
|
253 |
+
csv_data.append(flag_option)
|
254 |
+
csv_data.append(username if username is not None else "")
|
255 |
+
csv_data.append(str(datetime.datetime.now()))
|
256 |
+
|
257 |
+
json_obj = {}
|
258 |
+
for idx, (component, sample) in enumerate(zip(self.components, flag_data)):
|
259 |
+
save_dir = Path(
|
260 |
+
flagging_dir
|
261 |
+
) / client_utils.strip_invalid_filename_characters(
|
262 |
+
getattr(component, "label", None) or f"component {idx}"
|
263 |
+
)
|
264 |
+
label = getattr(component, "label", None) or f"component {idx}"
|
265 |
+
if gradio_utils.is_update(sample):
|
266 |
+
value = str(sample)
|
267 |
+
else:
|
268 |
+
value = component.flag(sample, flag_dir=save_dir) if sample is not None else None
|
269 |
+
json_obj[label] = value
|
270 |
+
|
271 |
+
json_obj['flag'] = flag_option
|
272 |
+
json_obj['username'] = username if username is not None else ""
|
273 |
+
json_obj['timestamp'] = str(datetime.datetime.now())
|
274 |
+
|
275 |
+
with open(log_filepath, "a", encoding="utf-8") as jsonl_file:
|
276 |
+
jsonl_file.write(json.dumps(json_obj, ensure_ascii=False) + "\n")
|
277 |
+
|
278 |
+
self.num_lines += 1
|
279 |
+
return self.num_lines
|
280 |
+
|
281 |
+
@document()
|
282 |
+
class VisionJsonlLogger(CustomJsonlLogger):
|
283 |
+
# ! must save the image
|
284 |
+
def flag(
|
285 |
+
self,
|
286 |
+
flag_data: list[Any],
|
287 |
+
flag_option: str = "",
|
288 |
+
username: Union[str, None] = None,
|
289 |
+
) -> int:
|
290 |
+
import datetime
|
291 |
+
from shutil import copyfile
|
292 |
+
flagging_dir = self.flagging_dir
|
293 |
+
log_filepath = Path(flagging_dir) / "log.jsonl"
|
294 |
+
image_dir = Path(flagging_dir) / "images"
|
295 |
+
is_new = not Path(log_filepath).exists()
|
296 |
+
os.makedirs(image_dir, exist_ok=True)
|
297 |
+
headers = [
|
298 |
+
getattr(component, "label", None) or f"component {idx}"
|
299 |
+
for idx, component in enumerate(self.components)
|
300 |
+
] + [
|
301 |
+
"flag",
|
302 |
+
"username",
|
303 |
+
"timestamp",
|
304 |
+
]
|
305 |
+
|
306 |
+
csv_data = []
|
307 |
+
for idx, (component, sample) in enumerate(zip(self.components, flag_data)):
|
308 |
+
save_dir = Path(
|
309 |
+
flagging_dir
|
310 |
+
) / client_utils.strip_invalid_filename_characters(
|
311 |
+
getattr(component, "label", None) or f"component {idx}"
|
312 |
+
)
|
313 |
+
if gradio_utils.is_update(sample):
|
314 |
+
csv_data.append(str(sample))
|
315 |
+
else:
|
316 |
+
csv_data.append(
|
317 |
+
component.flag(sample, flag_dir=save_dir)
|
318 |
+
if sample is not None
|
319 |
+
else ""
|
320 |
+
)
|
321 |
+
csv_data.append(flag_option)
|
322 |
+
csv_data.append(username if username is not None else "")
|
323 |
+
csv_data.append(str(datetime.datetime.now()))
|
324 |
+
|
325 |
+
json_obj = {}
|
326 |
+
for idx, (component, sample) in enumerate(zip(self.components, flag_data)):
|
327 |
+
save_dir = Path(
|
328 |
+
flagging_dir
|
329 |
+
) / client_utils.strip_invalid_filename_characters(
|
330 |
+
getattr(component, "label", None) or f"component {idx}"
|
331 |
+
)
|
332 |
+
label = getattr(component, "label", None) or f"component {idx}"
|
333 |
+
if gradio_utils.is_update(sample):
|
334 |
+
value = str(sample)
|
335 |
+
else:
|
336 |
+
value = component.flag(sample, flag_dir=save_dir) if sample is not None else None
|
337 |
+
if isinstance(value, list):
|
338 |
+
# Expecting history
|
339 |
+
from .multimodal_chat_interface import gradio_history_to_vision_conversations_paths
|
340 |
+
conversations, image_paths = gradio_history_to_vision_conversations_paths(value)
|
341 |
+
new_paths = [
|
342 |
+
os.path.join(image_dir, str(datetime.datetime.now()) + os.path.basename(p))
|
343 |
+
for p in image_paths
|
344 |
+
]
|
345 |
+
for np, ip in zip(new_paths, image_paths):
|
346 |
+
copyfile(ip, np)
|
347 |
+
json_obj[label] = conversations
|
348 |
+
json_obj[label + "-images"] = new_paths
|
349 |
+
else:
|
350 |
+
json_obj[label] = value
|
351 |
+
|
352 |
+
json_obj['flag'] = flag_option
|
353 |
+
json_obj['username'] = username if username is not None else ""
|
354 |
+
json_obj['timestamp'] = str(datetime.datetime.now())
|
355 |
+
|
356 |
+
with open(log_filepath, "a", encoding="utf-8") as jsonl_file:
|
357 |
+
jsonl_file.write(json.dumps(json_obj, ensure_ascii=False) + "\n")
|
358 |
+
|
359 |
+
self.num_lines += 1
|
360 |
+
return self.num_lines
|
361 |
+
|
362 |
+
|
363 |
+
|
364 |
+
|
365 |
+
|
366 |
+
def get_preference_radio():
|
367 |
+
pref_choice = gr.Radio(
|
368 |
+
['1 Better', '2 Better', 'Add best', 'dirty/undecided'],
|
369 |
+
label='preference',
|
370 |
+
info="Indicate if 1 or 2 is better. If both not excellent, pick 'Add best' and write the better one below. If question or answer is problematic, cannot decide, then choose dirty/undecided."
|
371 |
+
)
|
372 |
+
return pref_choice
|
373 |
+
|
374 |
+
|
375 |
+
|
376 |
+
def vision_submit_vision_response_stream_multiturn_engine_yhistory(
|
377 |
+
message: str,
|
378 |
+
input_image: str,
|
379 |
+
history: List[List[str]],
|
380 |
+
temperature: float,
|
381 |
+
max_tokens: int,
|
382 |
+
system_prompt: Optional[str] = SYSTEM_PROMPT,
|
383 |
+
image_token: Optional[str] = IMAGE_TOKEN,
|
384 |
+
):
|
385 |
+
# ! Add message and input_image into the history and submit
|
386 |
+
message = message.strip()
|
387 |
+
if message == "":
|
388 |
+
gr.Warning(f'Input text cannot be empty')
|
389 |
+
yield history
|
390 |
+
|
391 |
+
new_history = history
|
392 |
+
if input_image is not None and os.path.exists(input_image):
|
393 |
+
# ! image exist, so add message if it's not empty
|
394 |
+
new_history = new_history + [[(input_image,), None]]
|
395 |
+
if message != "":
|
396 |
+
new_history = new_history + [[message, None]]
|
397 |
+
else:
|
398 |
+
# ! message cannot be empty if there is no input_image
|
399 |
+
if message == "":
|
400 |
+
gr.Warning(f'Input text cannot be empty!')
|
401 |
+
yield history
|
402 |
+
return
|
403 |
+
else:
|
404 |
+
new_history = new_history + [[message, None]]
|
405 |
+
|
406 |
+
yield new_history
|
407 |
+
|
408 |
+
# ! yield current history
|
409 |
+
# use vision_chat_response_stream_multiturn_engine
|
410 |
+
response = None
|
411 |
+
for response, num_tokens in vision_chat_response_stream_multiturn_engine(
|
412 |
+
history=new_history,
|
413 |
+
temperature=temperature, max_tokens=max_tokens, system_prompt=system_prompt,
|
414 |
+
image_token=image_token,
|
415 |
+
):
|
416 |
+
yield new_history[:-1] + [[message, response]]
|
417 |
+
|
418 |
+
if response is not None:
|
419 |
+
yield new_history[:-1] + [[message, response]]
|
420 |
+
|
421 |
+
|
422 |
+
def vision_submit_2_histories(
|
423 |
+
message: str,
|
424 |
+
input_image: str,
|
425 |
+
history1: List[List[str]],
|
426 |
+
history2: List[List[str]],
|
427 |
+
temperature: float,
|
428 |
+
max_tokens: int,
|
429 |
+
system_prompt: Optional[str] = SYSTEM_PROMPT,
|
430 |
+
image_token: Optional[str] = IMAGE_TOKEN,
|
431 |
+
):
|
432 |
+
# need to yield 2 history
|
433 |
+
new_history1 = history1
|
434 |
+
new_history2 = history2
|
435 |
+
for his in vision_submit_vision_response_stream_multiturn_engine_yhistory(
|
436 |
+
message, input_image, history1, temperature, max_tokens, system_prompt, image_token,
|
437 |
+
):
|
438 |
+
new_history1 = his
|
439 |
+
yield new_history1, new_history2
|
440 |
+
|
441 |
+
for his in vision_submit_vision_response_stream_multiturn_engine_yhistory(
|
442 |
+
message, input_image, history2, temperature, max_tokens, system_prompt, image_token,
|
443 |
+
):
|
444 |
+
new_history2 = his
|
445 |
+
yield new_history1, new_history2
|
446 |
+
|
447 |
+
|
448 |
+
def undo_history_until_last_assistant_turn_message(history):
|
449 |
+
history = undo_history(history)
|
450 |
+
while len(history) > 0 and history[-1][-1] is None:
|
451 |
+
history = undo_history(history)
|
452 |
+
return history, history
|
453 |
+
|
454 |
+
|
455 |
+
|
456 |
+
def replace_last_response(input_text: str, history: List[Tuple[str, str]]):
|
457 |
+
# replace the last response with input_text
|
458 |
+
input_text = input_text.strip()
|
459 |
+
if input_text == "":
|
460 |
+
gr.Warning(f'prompt empty! dont send empty prompt')
|
461 |
+
return "", history
|
462 |
+
if len(history) == 0:
|
463 |
+
gr.Warning(f'History empty, cannot replace')
|
464 |
+
return input_text, history
|
465 |
+
history[-1][-1] = input_text
|
466 |
+
return "", history
|
467 |
+
|
468 |
+
|
469 |
+
# def load_image_from_gallery(selected_state: gr.SelectData):
|
470 |
+
# convo = sft_data_list[selected_state.index]
|
471 |
+
# dirname = sft_dirname
|
472 |
+
# image_path = os.path.join(dirname, convo['image'])
|
473 |
+
# return image_path
|
474 |
+
|
475 |
+
def load_image_from_gallery(data_list, selected_state: gr.SelectData):
|
476 |
+
image_path = data_list[selected_state.index]
|
477 |
+
# dirname = sft_dirname
|
478 |
+
# image_path = os.path.join(dirname, convo['image'])
|
479 |
+
return image_path
|
480 |
+
|
481 |
+
|
482 |
+
@register_demo
|
483 |
+
class VisionLivePreferencePickDemo(VisionChatInterfaceDemo):
|
484 |
+
@property
|
485 |
+
def examples(self):
|
486 |
+
return [
|
487 |
+
["What's strange about this image?", "assets/dog_monalisa.jpeg",],
|
488 |
+
["Explain why the sky is blue.", None,],
|
489 |
+
]
|
490 |
+
|
491 |
+
@property
|
492 |
+
def tab_name(self):
|
493 |
+
return "Vision Live Preference"
|
494 |
+
|
495 |
+
def create_demo(
|
496 |
+
self,
|
497 |
+
title: str | None = None,
|
498 |
+
description: str | None = None,
|
499 |
+
**kwargs
|
500 |
+
) -> gr.Blocks:
|
501 |
+
system_prompt = kwargs.get("system_prompt", SYSTEM_PROMPT)
|
502 |
+
max_tokens = kwargs.get("max_tokens", MAX_TOKENS)
|
503 |
+
temperature = kwargs.get("temperature", TEMPERATURE)
|
504 |
+
model_name = kwargs.get("model_name", MODEL_NAME)
|
505 |
+
|
506 |
+
log_folder = os.path.join(PREF_DIR, "live_preference_pick")
|
507 |
+
description = f"""
|
508 |
+
## Live generation preference picking
|
509 |
+
Live generation is similar to the Preference Picking demo, except that linguists can come up with questions/prompts **on their own** instead of pre-existing data.
|
510 |
+
|
511 |
+
PREF_DIR: {log_folder}
|
512 |
+
"""
|
513 |
+
|
514 |
+
instruction_content = f"""
|
515 |
+
### Tasks
|
516 |
+
You are enabled to freely build 2 different conversations using the model and pick the better conversations.
|
517 |
+
You can also create best responses if model's generated ones are not good.
|
518 |
+
|
519 |
+
### Requirements
|
520 |
+
The 2 conversations must share at least the first user query. Other than that, the length, number of turns, user queries (except the first one) can vary.
|
521 |
+
For example:
|
522 |
+
```
|
523 |
+
# Valid conversation pairs
|
524 |
+
"User: Hello, 1+1=?" -> "Bot: 1+1=2" -> "User: what about 123+13?" -> "Bot: 123+13=136"
|
525 |
+
-> "Bot: I dont know"
|
526 |
+
|
527 |
+
"User: Hello, 1+1=?" -> "Bot: 1+1=2" -> "User: what about 123+13?" -> "Bot: 123+13=136"
|
528 |
+
-> "Bot: 1+1=3" -> "User: that's wrong!" -> "Bot: Im sorry man."
|
529 |
+
```
|
530 |
+
|
531 |
+
```
|
532 |
+
# Invalid pairs:
|
533 |
+
"User: Hello, 1+1=?" -> "Bot: 1+1=2"
|
534 |
+
"User: Tell me a joke" -> "Bot: here is the joke for your..."
|
535 |
+
```
|
536 |
+
|
537 |
+
### Steps to proceed:
|
538 |
+
There are multiple buttons:
|
539 |
+
* `Submit both`: Submit the text prompt to both chatboxes, expect different (or same) answers.
|
540 |
+
* `Regenerate`: Regenerate the responses of both chatboxes from the last user queries.
|
541 |
+
* `Clear`: Clear both chatboxes.
|
542 |
+
|
543 |
+
The following numbered buttons (1 or 2) is applied to only Bot-1 or Bot-2 respectively.
|
544 |
+
* `Submit-1`: Submit the text prompt only one chatbot (1 or 2).
|
545 |
+
* `Undo-1`: Undo the last generation (both last response and query)
|
546 |
+
* `Regen-1`: Regenerate the last response.
|
547 |
+
* `Replace-1`: Replace the last response with a better response (in case the last response is incorrect, unsatisfactory)
|
548 |
+
|
549 |
+
"""
|
550 |
+
callback = VisionJsonlLogger()
|
551 |
+
with gr.Blocks(css=CSS) as pdemo:
|
552 |
+
gr.Markdown(description)
|
553 |
+
|
554 |
+
with gr.Accordion(label="Instructions and Guidelines", open=False):
|
555 |
+
gr.Markdown(instruction_content)
|
556 |
+
|
557 |
+
with gr.Accordion(label="Additional input", open=False):
|
558 |
+
temp = gr.Number(value=temperature, label='Temperature', info="Higher -> more random")
|
559 |
+
length = gr.Number(value=max_tokens, label='Max tokens', info='Increase if want more generation')
|
560 |
+
# freq_pen = gr.Number(value=frequence_penalty, label='Frequency penalty', info='> 0 encourage new tokens over repeated tokens')
|
561 |
+
# pres_pen = gr.Number(value=presence_penalty, label='Presence penalty', info='> 0 encourage new tokens, < 0 encourage existing tokens')
|
562 |
+
# stop_strings = gr.Textbox(value="<s>,</s>,<|im_start|>", label='Stop strings', info='Comma-separated string to stop generation.', lines=1)
|
563 |
+
system_prompt = gr.Textbox(value=system_prompt, label='system_prompt', lines=1)
|
564 |
+
|
565 |
+
|
566 |
+
with gr.Row():
|
567 |
+
chatbot_1 = gr.Chatbot(
|
568 |
+
[],
|
569 |
+
label="Bot-1",
|
570 |
+
elem_id="chatbot-1",
|
571 |
+
bubble_full_width=False,
|
572 |
+
latex_delimiters=[
|
573 |
+
# { "left": "$", "right": "$", "display": False},
|
574 |
+
{ "left": "$$", "right": "$$", "display": True},
|
575 |
+
],
|
576 |
+
show_copy_button=True,
|
577 |
+
layout="panel" if USE_PANEL else "bubble",
|
578 |
+
height=CHATBOT_HEIGHT,
|
579 |
+
)
|
580 |
+
chatbot_2 = gr.Chatbot(
|
581 |
+
[],
|
582 |
+
label="Bot-2",
|
583 |
+
elem_id="chatbot-2",
|
584 |
+
bubble_full_width=False,
|
585 |
+
latex_delimiters=[
|
586 |
+
# { "left": "$", "right": "$", "display": False},
|
587 |
+
{ "left": "$$", "right": "$$", "display": True},
|
588 |
+
],
|
589 |
+
show_copy_button=True,
|
590 |
+
layout="panel" if USE_PANEL else "bubble",
|
591 |
+
height=CHATBOT_HEIGHT,
|
592 |
+
)
|
593 |
+
|
594 |
+
with gr.Row():
|
595 |
+
input_text = gr.Textbox(
|
596 |
+
scale=6,
|
597 |
+
lines=12,
|
598 |
+
# lines=4,
|
599 |
+
max_lines=40,
|
600 |
+
show_label=False,
|
601 |
+
placeholder="Enter text and press enter, or upload an image",
|
602 |
+
container=False,
|
603 |
+
)
|
604 |
+
# submit will submit the same input text to both responses
|
605 |
+
input_image = gr.Image(
|
606 |
+
label="input_image", type="filepath", scale=3,
|
607 |
+
# height=250,
|
608 |
+
)
|
609 |
+
with gr.Row():
|
610 |
+
gen_submit = gr.Button('Send both', scale=1, variant='primary')
|
611 |
+
# regenerate should not care about input_text, it just undo the previous history
|
612 |
+
# regen_submit = gr.Button('Regenerate', scale=1)
|
613 |
+
clear_btn = gr.Button('Clear', scale=1)
|
614 |
+
# submit
|
615 |
+
with gr.Row():
|
616 |
+
chat1_submit = gr.Button('Send-1', variant='primary')
|
617 |
+
chat1_undo = gr.Button('Undo-1')
|
618 |
+
# chat1_regenerate = gr.Button('Regen-1')
|
619 |
+
chat1_replace = gr.Button('Replace-1')
|
620 |
+
|
621 |
+
chat2_submit = gr.Button('Send-2', variant='primary')
|
622 |
+
chat2_undo = gr.Button('Undo-2')
|
623 |
+
# chat2_regenerate = gr.Button('Regen-2')
|
624 |
+
chat2_replace = gr.Button('Replace-2')
|
625 |
+
gr.Markdown(f'**Do not click `Record Choice` twice with the same data sample!**')
|
626 |
+
with gr.Row():
|
627 |
+
pref_choice = get_preference_radio()
|
628 |
+
|
629 |
+
# with gr.Row():
|
630 |
+
# text_replace = gr.Textbox(
|
631 |
+
# placeholder="If both responses are not good, write a better response here. Only apply to the last response.",
|
632 |
+
# lines=2,
|
633 |
+
# max_lines=30,
|
634 |
+
# scale=6,
|
635 |
+
# label="best_response"
|
636 |
+
# )
|
637 |
+
submit_choice_btn = gr.Button('Record Choice', variant='secondary')
|
638 |
+
|
639 |
+
|
640 |
+
from functools import partial
|
641 |
+
|
642 |
+
with gr.Row():
|
643 |
+
gr.Examples(
|
644 |
+
label="Random images",
|
645 |
+
examples=[[x] for x in EXAMPLE_IMAGE_PATHS],
|
646 |
+
inputs=input_image,
|
647 |
+
cache_examples=False,
|
648 |
+
examples_per_page=100,
|
649 |
+
)
|
650 |
+
|
651 |
+
for k, plist in IMAGE_GLOB_PATHS.items():
|
652 |
+
print(f'{k}: {plist[:5]}')
|
653 |
+
gr.Markdown(f"{k}: {IMAGE_GLOB_DESCS[k]}")
|
654 |
+
gallery = gr.Gallery(
|
655 |
+
label=k,
|
656 |
+
value=plist,
|
657 |
+
allow_preview=False,
|
658 |
+
columns=10,
|
659 |
+
# rows=2,
|
660 |
+
height=250,
|
661 |
+
)
|
662 |
+
def _load_image_from_gallery(selected_state: gr.SelectData):
|
663 |
+
image_path = selected_state.value['image']['path']
|
664 |
+
print(f'Select: {image_path}')
|
665 |
+
return image_path
|
666 |
+
gallery.select(
|
667 |
+
_load_image_from_gallery,
|
668 |
+
# lambda select: plist[select.index],
|
669 |
+
# inputs=,
|
670 |
+
outputs=[input_image],
|
671 |
+
queue=False
|
672 |
+
)
|
673 |
+
|
674 |
+
# ! events for submit choices
|
675 |
+
submit_choice_btn.click(
|
676 |
+
lambda: gr.Button(value="Saving...", interactive=False, variant='stop'),
|
677 |
+
None,
|
678 |
+
submit_choice_btn,
|
679 |
+
queue=False,
|
680 |
+
api_name=False,
|
681 |
+
)
|
682 |
+
visual_feedback = True
|
683 |
+
def flag_method(request: gr.Request, *args):
|
684 |
+
# ! must save the image somewhere
|
685 |
+
try:
|
686 |
+
callback.flag(args)
|
687 |
+
except Exception as e:
|
688 |
+
print(f"Error while flagging: {e}")
|
689 |
+
if visual_feedback:
|
690 |
+
return "Error!"
|
691 |
+
if not visual_feedback:
|
692 |
+
return
|
693 |
+
gr.Info(f'Saving preference sucessful ({args[0]})')
|
694 |
+
time.sleep(1) # to provide enough time for the user to observe button change
|
695 |
+
return gr.Button(value="Record Choice", interactive=True)
|
696 |
+
|
697 |
+
callback.setup([chatbot_1, chatbot_2, pref_choice], log_folder)
|
698 |
+
submit_choice_btn.click(
|
699 |
+
flag_method, [chatbot_1, chatbot_2, pref_choice], submit_choice_btn,
|
700 |
+
preprocess=False, queue=False, api_name=False
|
701 |
+
)
|
702 |
+
|
703 |
+
# ! button evenrs
|
704 |
+
from gradio.events import Dependency, EventListenerMethod, on
|
705 |
+
generate_sub_events_both = [input_text.submit, gen_submit.click]
|
706 |
+
on(
|
707 |
+
generate_sub_events_both,
|
708 |
+
vision_submit_2_histories,
|
709 |
+
[
|
710 |
+
input_text, input_image, chatbot_1, chatbot_2,
|
711 |
+
temp, length, system_prompt
|
712 |
+
],
|
713 |
+
[chatbot_1, chatbot_2],
|
714 |
+
api_name=False,
|
715 |
+
queue=True,
|
716 |
+
).then(
|
717 |
+
lambda mes, img: ("", None),
|
718 |
+
[input_text, input_image],
|
719 |
+
[input_text, input_image],
|
720 |
+
api_name=False,
|
721 |
+
queue=False,
|
722 |
+
)
|
723 |
+
clear_btn.click(
|
724 |
+
lambda c1, c2, txt, img: ([], [], "", None),
|
725 |
+
[chatbot_1, chatbot_2, input_text, input_image],
|
726 |
+
[chatbot_1, chatbot_2, input_text, input_image],
|
727 |
+
api_name=False,
|
728 |
+
queue=True,
|
729 |
+
)
|
730 |
+
chat1_submit.click(
|
731 |
+
vision_submit_vision_response_stream_multiturn_engine_yhistory,
|
732 |
+
[
|
733 |
+
input_text, input_image, chatbot_1,
|
734 |
+
temp, length, system_prompt,
|
735 |
+
],
|
736 |
+
[chatbot_1],
|
737 |
+
api_name=False,
|
738 |
+
queue=True,
|
739 |
+
).then(
|
740 |
+
lambda mes, img: ("", None),
|
741 |
+
[input_text, input_image],
|
742 |
+
[input_text, input_image],
|
743 |
+
api_name=False,
|
744 |
+
queue=False,
|
745 |
+
)
|
746 |
+
chat2_submit.click(
|
747 |
+
vision_submit_vision_response_stream_multiturn_engine_yhistory,
|
748 |
+
[
|
749 |
+
input_text, input_image, chatbot_2,
|
750 |
+
temp, length, system_prompt,
|
751 |
+
],
|
752 |
+
[chatbot_2],
|
753 |
+
api_name=False,
|
754 |
+
queue=True,
|
755 |
+
).then(
|
756 |
+
lambda mes, img: ("", None),
|
757 |
+
[input_text, input_image],
|
758 |
+
[input_text, input_image],
|
759 |
+
api_name=False,
|
760 |
+
queue=False,
|
761 |
+
)
|
762 |
+
chat1_undo.click(
|
763 |
+
undo_history_until_last_assistant_turn,
|
764 |
+
chatbot_1,
|
765 |
+
[chatbot_1, input_text],
|
766 |
+
api_name=False,
|
767 |
+
queue=True,
|
768 |
+
)
|
769 |
+
chat2_undo.click(
|
770 |
+
undo_history_until_last_assistant_turn,
|
771 |
+
chatbot_2,
|
772 |
+
[chatbot_2, input_text],
|
773 |
+
api_name=False,
|
774 |
+
queue=True,
|
775 |
+
)
|
776 |
+
chat1_replace.click(
|
777 |
+
replace_last_response,
|
778 |
+
[input_text, chatbot_1],
|
779 |
+
[input_text, chatbot_1],
|
780 |
+
api_name=False,
|
781 |
+
queue=True,
|
782 |
+
)
|
783 |
+
chat2_replace.click(
|
784 |
+
replace_last_response,
|
785 |
+
[input_text, chatbot_2],
|
786 |
+
[input_text, chatbot_2],
|
787 |
+
api_name=False,
|
788 |
+
queue=True,
|
789 |
+
)
|
790 |
+
|
791 |
+
|
792 |
+
|
793 |
+
|
794 |
+
return pdemo
|
multipurpose_chatbot/demos/rag_chat_interface.py
ADDED
@@ -0,0 +1,638 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
from gradio.themes import ThemeClass as Theme
|
3 |
+
import numpy as np
|
4 |
+
import argparse
|
5 |
+
import gradio as gr
|
6 |
+
from typing import Any, Iterator
|
7 |
+
from typing import Iterator, List, Optional, Tuple
|
8 |
+
import filelock
|
9 |
+
import glob
|
10 |
+
import json
|
11 |
+
import time
|
12 |
+
from gradio.routes import Request
|
13 |
+
from gradio.utils import SyncToAsyncIterator, async_iteration
|
14 |
+
from gradio.helpers import special_args
|
15 |
+
import anyio
|
16 |
+
from typing import AsyncGenerator, Callable, Literal, Union, cast, Generator
|
17 |
+
|
18 |
+
from gradio_client.documentation import document, set_documentation_group
|
19 |
+
from gradio.components import Button, Component
|
20 |
+
from gradio.events import Dependency, EventListenerMethod
|
21 |
+
from typing import List, Optional, Union, Dict, Tuple
|
22 |
+
from tqdm.auto import tqdm
|
23 |
+
from huggingface_hub import snapshot_download
|
24 |
+
from gradio.themes import ThemeClass as Theme
|
25 |
+
|
26 |
+
from .base_demo import register_demo, get_demo_class, BaseDemo
|
27 |
+
|
28 |
+
import inspect
|
29 |
+
from typing import AsyncGenerator, Callable, Literal, Union, cast
|
30 |
+
|
31 |
+
import anyio
|
32 |
+
from gradio_client import utils as client_utils
|
33 |
+
from gradio_client.documentation import document
|
34 |
+
|
35 |
+
from gradio.blocks import Blocks
|
36 |
+
from gradio.components import (
|
37 |
+
Button,
|
38 |
+
Chatbot,
|
39 |
+
Component,
|
40 |
+
Markdown,
|
41 |
+
State,
|
42 |
+
Textbox,
|
43 |
+
get_component_instance,
|
44 |
+
)
|
45 |
+
from gradio.events import Dependency, on
|
46 |
+
from gradio.helpers import create_examples as Examples # noqa: N812
|
47 |
+
from gradio.helpers import special_args
|
48 |
+
from gradio.layouts import Accordion, Group, Row
|
49 |
+
from gradio.routes import Request
|
50 |
+
from gradio.themes import ThemeClass as Theme
|
51 |
+
from gradio.utils import SyncToAsyncIterator, async_iteration
|
52 |
+
|
53 |
+
|
54 |
+
from ..globals import MODEL_ENGINE, RAG_CURRENT_FILE, RAG_EMBED, load_embeddings, get_rag_embeddings
|
55 |
+
|
56 |
+
from .chat_interface import (
|
57 |
+
SYSTEM_PROMPT,
|
58 |
+
MODEL_NAME,
|
59 |
+
MAX_TOKENS,
|
60 |
+
TEMPERATURE,
|
61 |
+
CHAT_EXAMPLES,
|
62 |
+
gradio_history_to_openai_conversations,
|
63 |
+
gradio_history_to_conversation_prompt,
|
64 |
+
DATETIME_FORMAT,
|
65 |
+
get_datetime_string,
|
66 |
+
format_conversation,
|
67 |
+
chat_response_stream_multiturn_engine,
|
68 |
+
ChatInterfaceDemo,
|
69 |
+
CustomizedChatInterface,
|
70 |
+
)
|
71 |
+
|
72 |
+
from ..configs import (
|
73 |
+
CHUNK_SIZE,
|
74 |
+
CHUNK_OVERLAP,
|
75 |
+
RAG_EMBED_MODEL_NAME,
|
76 |
+
)
|
77 |
+
|
78 |
+
RAG_CURRENT_VECTORSTORE = None
|
79 |
+
|
80 |
+
|
81 |
+
def load_document_split_vectorstore(file_path):
|
82 |
+
global RAG_CURRENT_FILE, RAG_EMBED, RAG_CURRENT_VECTORSTORE
|
83 |
+
from langchain.text_splitter import RecursiveCharacterTextSplitter
|
84 |
+
from langchain_community.embeddings import HuggingFaceEmbeddings, HuggingFaceBgeEmbeddings
|
85 |
+
from langchain_community.vectorstores import Chroma, FAISS
|
86 |
+
from langchain_community.document_loaders import PyPDFLoader, Docx2txtLoader, TextLoader
|
87 |
+
splitter = RecursiveCharacterTextSplitter(chunk_size=CHUNK_SIZE, chunk_overlap=CHUNK_OVERLAP)
|
88 |
+
if file_path.endswith('.pdf'):
|
89 |
+
loader = PyPDFLoader(file_path)
|
90 |
+
elif file_path.endswith('.docx'):
|
91 |
+
loader = Docx2txtLoader(file_path)
|
92 |
+
elif file_path.endswith('.txt'):
|
93 |
+
loader = TextLoader(file_path)
|
94 |
+
splits = loader.load_and_split(splitter)
|
95 |
+
RAG_CURRENT_VECTORSTORE = FAISS.from_texts(texts=[s.page_content for s in splits], embedding=get_rag_embeddings())
|
96 |
+
return RAG_CURRENT_VECTORSTORE
|
97 |
+
|
98 |
+
def docs_to_context_content(docs: List[Any]):
|
99 |
+
content = "\n".join([d.page_content for d in docs])
|
100 |
+
return content
|
101 |
+
|
102 |
+
|
103 |
+
DOC_TEMPLATE = """###
|
104 |
+
{content}
|
105 |
+
###
|
106 |
+
|
107 |
+
"""
|
108 |
+
|
109 |
+
DOC_INSTRUCTION = """Answer the following query exclusively based on the information provided in the document above. \
|
110 |
+
If the information is not found, please say so instead of making up facts! Remember to answer the question in the same language as the user query!
|
111 |
+
"""
|
112 |
+
|
113 |
+
|
114 |
+
def docs_to_rag_context(docs: List[Any], doc_instruction=None):
|
115 |
+
doc_instruction = doc_instruction or DOC_INSTRUCTION
|
116 |
+
content = docs_to_context_content(docs)
|
117 |
+
context = doc_instruction.strip() + "\n" + DOC_TEMPLATE.format(content=content)
|
118 |
+
return context
|
119 |
+
|
120 |
+
|
121 |
+
def maybe_get_doc_context(message, file_input, rag_num_docs: Optional[int] = 3):
|
122 |
+
doc_context = None
|
123 |
+
if file_input is not None:
|
124 |
+
if file_input == RAG_CURRENT_FILE:
|
125 |
+
# reuse
|
126 |
+
vectorstore = RAG_CURRENT_VECTORSTORE
|
127 |
+
print(f'Reuse vectorstore: {file_input}')
|
128 |
+
else:
|
129 |
+
vectorstore = load_document_split_vectorstore(file_input)
|
130 |
+
print(f'New vectorstore: {RAG_CURRENT_FILE} {file_input}')
|
131 |
+
RAG_CURRENT_FILE = file_input
|
132 |
+
docs = vectorstore.similarity_search(message, k=rag_num_docs)
|
133 |
+
doc_context = docs_to_rag_context(docs)
|
134 |
+
return doc_context
|
135 |
+
|
136 |
+
|
137 |
+
def chat_response_stream_multiturn_doc_engine(
|
138 |
+
message: str,
|
139 |
+
history: List[Tuple[str, str]],
|
140 |
+
file_input: Optional[str] = None,
|
141 |
+
temperature: float = 0.7,
|
142 |
+
max_tokens: int = 1024,
|
143 |
+
system_prompt: Optional[str] = SYSTEM_PROMPT,
|
144 |
+
rag_num_docs: Optional[int] = 3,
|
145 |
+
doc_instruction: Optional[str] = DOC_INSTRUCTION,
|
146 |
+
# profile: Optional[gr.OAuthProfile] = None,
|
147 |
+
):
|
148 |
+
global MODEL_ENGINE, RAG_CURRENT_FILE, RAG_EMBED, RAG_CURRENT_VECTORSTORE
|
149 |
+
if len(message) == 0:
|
150 |
+
raise gr.Error("The message cannot be empty!")
|
151 |
+
|
152 |
+
rag_num_docs = int(rag_num_docs)
|
153 |
+
doc_instruction = doc_instruction or DOC_INSTRUCTION
|
154 |
+
doc_context = None
|
155 |
+
if file_input is not None:
|
156 |
+
if file_input == RAG_CURRENT_FILE:
|
157 |
+
# reuse
|
158 |
+
vectorstore = RAG_CURRENT_VECTORSTORE
|
159 |
+
print(f'Reuse vectorstore: {file_input}')
|
160 |
+
else:
|
161 |
+
vectorstore = load_document_split_vectorstore(file_input)
|
162 |
+
print(f'New vectorstore: {RAG_CURRENT_FILE} {file_input}')
|
163 |
+
RAG_CURRENT_FILE = file_input
|
164 |
+
docs = vectorstore.similarity_search(message, k=rag_num_docs)
|
165 |
+
# doc_context = docs_to_rag_context(docs)
|
166 |
+
rag_content = docs_to_context_content(docs)
|
167 |
+
doc_context = doc_instruction.strip() + "\n" + DOC_TEMPLATE.format(content=rag_content)
|
168 |
+
|
169 |
+
if doc_context is not None:
|
170 |
+
message = f"{doc_context}\n\n{message}"
|
171 |
+
|
172 |
+
for response, num_tokens in chat_response_stream_multiturn_engine(
|
173 |
+
message, history, temperature, max_tokens, system_prompt
|
174 |
+
):
|
175 |
+
# ! yield another content which is doc_context
|
176 |
+
yield response, num_tokens, doc_context
|
177 |
+
|
178 |
+
|
179 |
+
|
180 |
+
class RagChatInterface(CustomizedChatInterface):
|
181 |
+
def __init__(
|
182 |
+
self,
|
183 |
+
fn: Callable[..., Any],
|
184 |
+
*,
|
185 |
+
chatbot: gr.Chatbot | None = None,
|
186 |
+
textbox: gr.Textbox | None = None,
|
187 |
+
additional_inputs: str | Component | list[str | Component] | None = None,
|
188 |
+
additional_inputs_accordion_name: str | None = None,
|
189 |
+
additional_inputs_accordion: str | gr.Accordion | None = None,
|
190 |
+
render_additional_inputs_fn: Callable | None = None,
|
191 |
+
examples: list[str] | None = None,
|
192 |
+
cache_examples: bool | None = None,
|
193 |
+
title: str | None = None,
|
194 |
+
description: str | None = None,
|
195 |
+
theme: Theme | str | None = None,
|
196 |
+
css: str | None = None,
|
197 |
+
js: str | None = None,
|
198 |
+
head: str | None = None,
|
199 |
+
analytics_enabled: bool | None = None,
|
200 |
+
submit_btn: str | Button | None = "Submit",
|
201 |
+
stop_btn: str | Button | None = "Stop",
|
202 |
+
retry_btn: str | Button | None = "🔄 Retry",
|
203 |
+
undo_btn: str | Button | None = "↩️ Undo",
|
204 |
+
clear_btn: str | Button | None = "🗑️ Clear",
|
205 |
+
autofocus: bool = True,
|
206 |
+
concurrency_limit: int | Literal['default'] | None = "default",
|
207 |
+
fill_height: bool = True
|
208 |
+
):
|
209 |
+
try:
|
210 |
+
super(gr.ChatInterface, self).__init__(
|
211 |
+
analytics_enabled=analytics_enabled,
|
212 |
+
mode="chat_interface",
|
213 |
+
css=css,
|
214 |
+
title=title or "Gradio",
|
215 |
+
theme=theme,
|
216 |
+
js=js,
|
217 |
+
head=head,
|
218 |
+
fill_height=fill_height,
|
219 |
+
)
|
220 |
+
except Exception as e:
|
221 |
+
# Handling some old gradio version with out fill_height
|
222 |
+
super(gr.ChatInterface, self).__init__(
|
223 |
+
analytics_enabled=analytics_enabled,
|
224 |
+
mode="chat_interface",
|
225 |
+
css=css,
|
226 |
+
title=title or "Gradio",
|
227 |
+
theme=theme,
|
228 |
+
js=js,
|
229 |
+
head=head,
|
230 |
+
# fill_height=fill_height,
|
231 |
+
)
|
232 |
+
self.concurrency_limit = concurrency_limit
|
233 |
+
self.fn = fn
|
234 |
+
self.render_additional_inputs_fn = render_additional_inputs_fn
|
235 |
+
self.is_async = inspect.iscoroutinefunction(
|
236 |
+
self.fn
|
237 |
+
) or inspect.isasyncgenfunction(self.fn)
|
238 |
+
self.is_generator = inspect.isgeneratorfunction(
|
239 |
+
self.fn
|
240 |
+
) or inspect.isasyncgenfunction(self.fn)
|
241 |
+
self.examples = examples
|
242 |
+
if self.space_id and cache_examples is None:
|
243 |
+
self.cache_examples = True
|
244 |
+
else:
|
245 |
+
self.cache_examples = cache_examples or False
|
246 |
+
self.buttons: list[Button | None] = []
|
247 |
+
|
248 |
+
if additional_inputs:
|
249 |
+
if not isinstance(additional_inputs, list):
|
250 |
+
additional_inputs = [additional_inputs]
|
251 |
+
self.additional_inputs = [
|
252 |
+
get_component_instance(i)
|
253 |
+
for i in additional_inputs # type: ignore
|
254 |
+
]
|
255 |
+
else:
|
256 |
+
self.additional_inputs = []
|
257 |
+
if additional_inputs_accordion_name is not None:
|
258 |
+
print(
|
259 |
+
"The `additional_inputs_accordion_name` parameter is deprecated and will be removed in a future version of Gradio. Use the `additional_inputs_accordion` parameter instead."
|
260 |
+
)
|
261 |
+
self.additional_inputs_accordion_params = {
|
262 |
+
"label": additional_inputs_accordion_name
|
263 |
+
}
|
264 |
+
if additional_inputs_accordion is None:
|
265 |
+
self.additional_inputs_accordion_params = {
|
266 |
+
"label": "Additional Inputs",
|
267 |
+
"open": False,
|
268 |
+
}
|
269 |
+
elif isinstance(additional_inputs_accordion, str):
|
270 |
+
self.additional_inputs_accordion_params = {
|
271 |
+
"label": additional_inputs_accordion
|
272 |
+
}
|
273 |
+
elif isinstance(additional_inputs_accordion, Accordion):
|
274 |
+
self.additional_inputs_accordion_params = (
|
275 |
+
additional_inputs_accordion.recover_kwargs(
|
276 |
+
additional_inputs_accordion.get_config()
|
277 |
+
)
|
278 |
+
)
|
279 |
+
else:
|
280 |
+
raise ValueError(
|
281 |
+
f"The `additional_inputs_accordion` parameter must be a string or gr.Accordion, not {type(additional_inputs_accordion)}"
|
282 |
+
)
|
283 |
+
|
284 |
+
with self:
|
285 |
+
if title:
|
286 |
+
Markdown(
|
287 |
+
f"<h1 style='text-align: center; margin-bottom: 1rem'>{self.title}</h1>"
|
288 |
+
)
|
289 |
+
if description:
|
290 |
+
Markdown(description)
|
291 |
+
|
292 |
+
if chatbot:
|
293 |
+
self.chatbot = chatbot.render()
|
294 |
+
else:
|
295 |
+
self.chatbot = Chatbot(
|
296 |
+
label="Chatbot", scale=1, height=200 if fill_height else None
|
297 |
+
)
|
298 |
+
|
299 |
+
with Row():
|
300 |
+
for btn in [retry_btn, undo_btn, clear_btn]:
|
301 |
+
if btn is not None:
|
302 |
+
if isinstance(btn, Button):
|
303 |
+
btn.render()
|
304 |
+
elif isinstance(btn, str):
|
305 |
+
btn = Button(btn, variant="secondary", size="sm")
|
306 |
+
else:
|
307 |
+
raise ValueError(
|
308 |
+
f"All the _btn parameters must be a gr.Button, string, or None, not {type(btn)}"
|
309 |
+
)
|
310 |
+
self.buttons.append(btn) # type: ignore
|
311 |
+
|
312 |
+
with Group():
|
313 |
+
with Row():
|
314 |
+
if textbox:
|
315 |
+
textbox.container = False
|
316 |
+
textbox.show_label = False
|
317 |
+
textbox_ = textbox.render()
|
318 |
+
assert isinstance(textbox_, Textbox)
|
319 |
+
self.textbox = textbox_
|
320 |
+
else:
|
321 |
+
self.textbox = Textbox(
|
322 |
+
container=False,
|
323 |
+
show_label=False,
|
324 |
+
label="Message",
|
325 |
+
placeholder="Type a message...",
|
326 |
+
scale=7,
|
327 |
+
autofocus=autofocus,
|
328 |
+
)
|
329 |
+
if submit_btn is not None:
|
330 |
+
if isinstance(submit_btn, Button):
|
331 |
+
submit_btn.render()
|
332 |
+
elif isinstance(submit_btn, str):
|
333 |
+
submit_btn = Button(
|
334 |
+
submit_btn,
|
335 |
+
variant="primary",
|
336 |
+
scale=2,
|
337 |
+
min_width=150,
|
338 |
+
)
|
339 |
+
else:
|
340 |
+
raise ValueError(
|
341 |
+
f"The submit_btn parameter must be a gr.Button, string, or None, not {type(submit_btn)}"
|
342 |
+
)
|
343 |
+
if stop_btn is not None:
|
344 |
+
if isinstance(stop_btn, Button):
|
345 |
+
stop_btn.visible = False
|
346 |
+
stop_btn.render()
|
347 |
+
elif isinstance(stop_btn, str):
|
348 |
+
stop_btn = Button(
|
349 |
+
stop_btn,
|
350 |
+
variant="stop",
|
351 |
+
visible=False,
|
352 |
+
scale=2,
|
353 |
+
min_width=150,
|
354 |
+
)
|
355 |
+
else:
|
356 |
+
raise ValueError(
|
357 |
+
f"The stop_btn parameter must be a gr.Button, string, or None, not {type(stop_btn)}"
|
358 |
+
)
|
359 |
+
self.num_tokens = Textbox(
|
360 |
+
container=False,
|
361 |
+
label="num_tokens",
|
362 |
+
placeholder="0 tokens",
|
363 |
+
scale=1,
|
364 |
+
interactive=False,
|
365 |
+
# autofocus=autofocus,
|
366 |
+
min_width=10
|
367 |
+
)
|
368 |
+
self.buttons.extend([submit_btn, stop_btn]) # type: ignore
|
369 |
+
|
370 |
+
self.fake_api_btn = Button("Fake API", visible=False)
|
371 |
+
self.fake_response_textbox = Textbox(label="Response", visible=False)
|
372 |
+
(
|
373 |
+
self.retry_btn,
|
374 |
+
self.undo_btn,
|
375 |
+
self.clear_btn,
|
376 |
+
self.submit_btn,
|
377 |
+
self.stop_btn,
|
378 |
+
) = self.buttons
|
379 |
+
|
380 |
+
if examples:
|
381 |
+
if self.is_generator:
|
382 |
+
examples_fn = self._examples_stream_fn
|
383 |
+
else:
|
384 |
+
examples_fn = self._examples_fn
|
385 |
+
|
386 |
+
self.examples_handler = Examples(
|
387 |
+
examples=examples,
|
388 |
+
inputs=[self.textbox] + self.additional_inputs,
|
389 |
+
outputs=self.chatbot,
|
390 |
+
fn=examples_fn,
|
391 |
+
)
|
392 |
+
|
393 |
+
any_unrendered_inputs = any(
|
394 |
+
not inp.is_rendered for inp in self.additional_inputs
|
395 |
+
)
|
396 |
+
if self.additional_inputs and any_unrendered_inputs:
|
397 |
+
with Accordion(**self.additional_inputs_accordion_params): # type: ignore
|
398 |
+
if self.render_additional_inputs_fn is not None:
|
399 |
+
self.render_additional_inputs_fn()
|
400 |
+
else:
|
401 |
+
for input_component in self.additional_inputs:
|
402 |
+
if not input_component.is_rendered:
|
403 |
+
input_component.render()
|
404 |
+
|
405 |
+
self.rag_content = gr.Textbox(
|
406 |
+
scale=4,
|
407 |
+
lines=16,
|
408 |
+
label='Retrieved RAG context',
|
409 |
+
placeholder="Rag context and instrution will show up here",
|
410 |
+
interactive=False
|
411 |
+
)
|
412 |
+
|
413 |
+
# The example caching must happen after the input components have rendered
|
414 |
+
if cache_examples:
|
415 |
+
client_utils.synchronize_async(self.examples_handler.cache)
|
416 |
+
|
417 |
+
self.saved_input = State()
|
418 |
+
self.chatbot_state = (
|
419 |
+
State(self.chatbot.value) if self.chatbot.value else State([])
|
420 |
+
)
|
421 |
+
|
422 |
+
self._setup_events()
|
423 |
+
self._setup_api()
|
424 |
+
|
425 |
+
def _setup_events(self) -> None:
|
426 |
+
from gradio.components import State
|
427 |
+
has_on = False
|
428 |
+
try:
|
429 |
+
from gradio.events import Dependency, EventListenerMethod, on
|
430 |
+
has_on = True
|
431 |
+
except ImportError as ie:
|
432 |
+
has_on = False
|
433 |
+
submit_fn = self._stream_fn if self.is_generator else self._submit_fn
|
434 |
+
if not self.is_generator:
|
435 |
+
raise NotImplementedError(f'should use generator')
|
436 |
+
|
437 |
+
if has_on:
|
438 |
+
# new version
|
439 |
+
submit_triggers = (
|
440 |
+
[self.textbox.submit, self.submit_btn.click]
|
441 |
+
if self.submit_btn
|
442 |
+
else [self.textbox.submit]
|
443 |
+
)
|
444 |
+
submit_event = (
|
445 |
+
on(
|
446 |
+
submit_triggers,
|
447 |
+
self._clear_and_save_textbox,
|
448 |
+
[self.textbox],
|
449 |
+
[self.textbox, self.saved_input],
|
450 |
+
api_name=False,
|
451 |
+
queue=False,
|
452 |
+
)
|
453 |
+
.then(
|
454 |
+
self._display_input,
|
455 |
+
[self.saved_input, self.chatbot_state],
|
456 |
+
[self.chatbot, self.chatbot_state],
|
457 |
+
api_name=False,
|
458 |
+
queue=False,
|
459 |
+
)
|
460 |
+
.then(
|
461 |
+
submit_fn,
|
462 |
+
[self.saved_input, self.chatbot_state] + self.additional_inputs,
|
463 |
+
[self.chatbot, self.chatbot_state, self.num_tokens, self.rag_content],
|
464 |
+
api_name=False,
|
465 |
+
)
|
466 |
+
)
|
467 |
+
self._setup_stop_events(submit_triggers, submit_event)
|
468 |
+
else:
|
469 |
+
raise ValueError(f'Better install new gradio version than 3.44.0')
|
470 |
+
|
471 |
+
if self.retry_btn:
|
472 |
+
retry_event = (
|
473 |
+
self.retry_btn.click(
|
474 |
+
self._delete_prev_fn,
|
475 |
+
[self.chatbot_state],
|
476 |
+
[self.chatbot, self.saved_input, self.chatbot_state],
|
477 |
+
api_name=False,
|
478 |
+
queue=False,
|
479 |
+
)
|
480 |
+
.then(
|
481 |
+
self._display_input,
|
482 |
+
[self.saved_input, self.chatbot_state],
|
483 |
+
[self.chatbot, self.chatbot_state],
|
484 |
+
api_name=False,
|
485 |
+
queue=False,
|
486 |
+
)
|
487 |
+
.then(
|
488 |
+
submit_fn,
|
489 |
+
[self.saved_input, self.chatbot_state] + self.additional_inputs,
|
490 |
+
[self.chatbot, self.chatbot_state, self.num_tokens, self.rag_content],
|
491 |
+
api_name=False,
|
492 |
+
)
|
493 |
+
)
|
494 |
+
self._setup_stop_events([self.retry_btn.click], retry_event)
|
495 |
+
|
496 |
+
if self.undo_btn:
|
497 |
+
self.undo_btn.click(
|
498 |
+
self._delete_prev_fn,
|
499 |
+
[self.chatbot_state],
|
500 |
+
[self.chatbot, self.saved_input, self.chatbot_state],
|
501 |
+
api_name=False,
|
502 |
+
queue=False,
|
503 |
+
).then(
|
504 |
+
lambda x: x,
|
505 |
+
[self.saved_input],
|
506 |
+
[self.textbox],
|
507 |
+
api_name=False,
|
508 |
+
queue=False,
|
509 |
+
)
|
510 |
+
# Reconfigure clear_btn to stop and clear text box
|
511 |
+
|
512 |
+
async def _stream_fn(
|
513 |
+
self,
|
514 |
+
message: str,
|
515 |
+
history_with_input,
|
516 |
+
request: Request,
|
517 |
+
*args,
|
518 |
+
) -> AsyncGenerator:
|
519 |
+
history = history_with_input[:-1]
|
520 |
+
inputs, _, _ = special_args(
|
521 |
+
self.fn, inputs=[message, history, *args], request=request
|
522 |
+
)
|
523 |
+
|
524 |
+
if self.is_async:
|
525 |
+
generator = self.fn(*inputs)
|
526 |
+
else:
|
527 |
+
generator = await anyio.to_thread.run_sync(
|
528 |
+
self.fn, *inputs, limiter=self.limiter
|
529 |
+
)
|
530 |
+
generator = SyncToAsyncIterator(generator, self.limiter)
|
531 |
+
|
532 |
+
# ! In case of error, yield the previous history & undo any generation before raising error
|
533 |
+
try:
|
534 |
+
first_response_pack = await async_iteration(generator)
|
535 |
+
if isinstance(first_response_pack, (tuple, list)):
|
536 |
+
first_response, num_tokens, rag_content = first_response_pack
|
537 |
+
else:
|
538 |
+
first_response, num_tokens, rag_content = first_response_pack, -1, ""
|
539 |
+
update = history + [[message, first_response]]
|
540 |
+
yield update, update, f"{num_tokens} toks", rag_content
|
541 |
+
except StopIteration:
|
542 |
+
update = history + [[message, None]]
|
543 |
+
yield update, update, "NaN toks", ""
|
544 |
+
except Exception as e:
|
545 |
+
yield history, history, "NaN toks", ""
|
546 |
+
raise e
|
547 |
+
|
548 |
+
try:
|
549 |
+
async for response_pack in generator:
|
550 |
+
if isinstance(response_pack, (tuple, list)):
|
551 |
+
response, num_tokens, rag_content = response_pack
|
552 |
+
else:
|
553 |
+
response, num_tokens, rag_content = response_pack, "NaN toks", ""
|
554 |
+
update = history + [[message, response]]
|
555 |
+
yield update, update, f"{num_tokens} toks", rag_content
|
556 |
+
except Exception as e:
|
557 |
+
yield history, history, "NaN toks", ""
|
558 |
+
raise e
|
559 |
+
|
560 |
+
|
561 |
+
|
562 |
+
@register_demo
|
563 |
+
class RagChatInterfaceDemo(ChatInterfaceDemo):
|
564 |
+
|
565 |
+
@property
|
566 |
+
def examples(self):
|
567 |
+
return [
|
568 |
+
["Explain how attention works.", "assets/attention_all_you_need.pdf"],
|
569 |
+
["Explain why the sky is blue.", None],
|
570 |
+
]
|
571 |
+
|
572 |
+
@property
|
573 |
+
def tab_name(self):
|
574 |
+
return "RAG Chat"
|
575 |
+
|
576 |
+
def create_demo(
|
577 |
+
self,
|
578 |
+
title: str | None = None,
|
579 |
+
description: str | None = None,
|
580 |
+
**kwargs
|
581 |
+
) -> gr.Blocks:
|
582 |
+
load_embeddings()
|
583 |
+
global RAG_EMBED
|
584 |
+
# assert RAG_EMBED is not None
|
585 |
+
print(F'{RAG_EMBED=}')
|
586 |
+
system_prompt = kwargs.get("system_prompt", SYSTEM_PROMPT)
|
587 |
+
max_tokens = kwargs.get("max_tokens", MAX_TOKENS)
|
588 |
+
temperature = kwargs.get("temperature", TEMPERATURE)
|
589 |
+
model_name = kwargs.get("model_name", MODEL_NAME)
|
590 |
+
rag_num_docs = kwargs.get("rag_num_docs", 3)
|
591 |
+
|
592 |
+
from ..configs import RAG_EMBED_MODEL_NAME
|
593 |
+
|
594 |
+
description = description or f"""Upload a long document to ask question about it with RAG. Embedding model {RAG_EMBED_MODEL_NAME}"""
|
595 |
+
|
596 |
+
additional_inputs = [
|
597 |
+
gr.File(label='Upload Document', file_count='single', file_types=['pdf', 'docx', 'txt']),
|
598 |
+
gr.Number(value=temperature, label='Temperature', min_width=20),
|
599 |
+
gr.Number(value=max_tokens, label='Max tokens', min_width=20),
|
600 |
+
gr.Textbox(value=system_prompt, label='System prompt', lines=2),
|
601 |
+
gr.Number(value=rag_num_docs, label='RAG Top-K', min_width=20),
|
602 |
+
gr.Textbox(value=DOC_INSTRUCTION, label='RAG instruction'),
|
603 |
+
]
|
604 |
+
def render_additional_inputs_fn():
|
605 |
+
additional_inputs[0].render()
|
606 |
+
with Row():
|
607 |
+
additional_inputs[1].render()
|
608 |
+
additional_inputs[2].render()
|
609 |
+
additional_inputs[4].render()
|
610 |
+
additional_inputs[3].render()
|
611 |
+
additional_inputs[5].render()
|
612 |
+
|
613 |
+
demo_chat = RagChatInterface(
|
614 |
+
chat_response_stream_multiturn_doc_engine,
|
615 |
+
chatbot=gr.Chatbot(
|
616 |
+
label=model_name,
|
617 |
+
bubble_full_width=False,
|
618 |
+
latex_delimiters=[
|
619 |
+
{ "left": "$", "right": "$", "display": False},
|
620 |
+
{ "left": "$$", "right": "$$", "display": True},
|
621 |
+
],
|
622 |
+
show_copy_button=True,
|
623 |
+
),
|
624 |
+
textbox=gr.Textbox(placeholder='Type message', lines=1, max_lines=128, min_width=200, scale=8),
|
625 |
+
submit_btn=gr.Button(value='Submit', variant="primary", scale=0),
|
626 |
+
# ! consider preventing the stop button
|
627 |
+
# stop_btn=None,
|
628 |
+
title=title,
|
629 |
+
description=description,
|
630 |
+
additional_inputs=additional_inputs,
|
631 |
+
render_additional_inputs_fn=render_additional_inputs_fn,
|
632 |
+
additional_inputs_accordion=gr.Accordion("Additional Inputs", open=True),
|
633 |
+
examples=self.examples,
|
634 |
+
cache_examples=False,
|
635 |
+
)
|
636 |
+
return demo_chat
|
637 |
+
|
638 |
+
|
multipurpose_chatbot/demos/text_completion.py
ADDED
@@ -0,0 +1,199 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
from gradio.themes import ThemeClass as Theme
|
3 |
+
import numpy as np
|
4 |
+
import argparse
|
5 |
+
import gradio as gr
|
6 |
+
from typing import Any, Iterator
|
7 |
+
from typing import Iterator, List, Optional, Tuple
|
8 |
+
import filelock
|
9 |
+
import glob
|
10 |
+
import json
|
11 |
+
import time
|
12 |
+
from gradio.routes import Request
|
13 |
+
from gradio.utils import SyncToAsyncIterator, async_iteration
|
14 |
+
from gradio.helpers import special_args
|
15 |
+
import anyio
|
16 |
+
from typing import AsyncGenerator, Callable, Literal, Union, cast, Generator
|
17 |
+
|
18 |
+
from gradio_client.documentation import document, set_documentation_group
|
19 |
+
from gradio.components import Button, Component
|
20 |
+
from gradio.events import Dependency, EventListenerMethod
|
21 |
+
from typing import List, Optional, Union, Dict, Tuple
|
22 |
+
from tqdm.auto import tqdm
|
23 |
+
from huggingface_hub import snapshot_download
|
24 |
+
|
25 |
+
|
26 |
+
import inspect
|
27 |
+
from typing import AsyncGenerator, Callable, Literal, Union, cast
|
28 |
+
|
29 |
+
import anyio
|
30 |
+
from gradio_client import utils as client_utils
|
31 |
+
from gradio_client.documentation import document
|
32 |
+
|
33 |
+
from gradio.blocks import Blocks
|
34 |
+
from gradio.components import (
|
35 |
+
Button,
|
36 |
+
Chatbot,
|
37 |
+
Component,
|
38 |
+
Markdown,
|
39 |
+
State,
|
40 |
+
Textbox,
|
41 |
+
get_component_instance,
|
42 |
+
)
|
43 |
+
from gradio.events import Dependency, on
|
44 |
+
from gradio.helpers import create_examples as Examples # noqa: N812
|
45 |
+
from gradio.helpers import special_args
|
46 |
+
from gradio.layouts import Accordion, Group, Row
|
47 |
+
from gradio.routes import Request
|
48 |
+
from gradio.themes import ThemeClass as Theme
|
49 |
+
from gradio.utils import SyncToAsyncIterator, async_iteration
|
50 |
+
|
51 |
+
|
52 |
+
from .base_demo import register_demo, get_demo_class, BaseDemo
|
53 |
+
|
54 |
+
|
55 |
+
from ..configs import (
|
56 |
+
SYSTEM_PROMPT,
|
57 |
+
MODEL_NAME,
|
58 |
+
MAX_TOKENS,
|
59 |
+
TEMPERATURE,
|
60 |
+
)
|
61 |
+
|
62 |
+
from ..globals import MODEL_ENGINE
|
63 |
+
|
64 |
+
|
65 |
+
def generate_text_completion_stream_engine(
|
66 |
+
message: str,
|
67 |
+
temperature: float,
|
68 |
+
max_tokens: int,
|
69 |
+
stop_strings: str = '<s>,</s>,<|im_start|>,<|im_end|>',
|
70 |
+
):
|
71 |
+
global MODEL_ENGINE
|
72 |
+
temperature = float(temperature)
|
73 |
+
# ! remove frequency_penalty
|
74 |
+
# frequency_penalty = float(frequency_penalty)
|
75 |
+
max_tokens = int(max_tokens)
|
76 |
+
# message = message.strip()
|
77 |
+
stop_strings = [x.strip() for x in stop_strings.strip().split(",")]
|
78 |
+
stop_strings = list(set(stop_strings + ['</s>', '<|im_start|>', '<|im_end|>']))
|
79 |
+
if message.strip() != message:
|
80 |
+
gr.Warning(f'There are preceding/trailing spaces in the message, may lead to unexpected behavior')
|
81 |
+
if len(message) == 0:
|
82 |
+
raise gr.Error("The message cannot be empty!")
|
83 |
+
num_tokens = len(MODEL_ENGINE.tokenizer.encode(message))
|
84 |
+
if num_tokens >= MODEL_ENGINE.max_position_embeddings - 128:
|
85 |
+
raise gr.Error(f"Conversation or prompt is too long ({num_tokens} toks), please clear the chatbox or try shorter input.")
|
86 |
+
|
87 |
+
outputs = None
|
88 |
+
response = None
|
89 |
+
num_tokens = -1
|
90 |
+
for j, outputs in enumerate(MODEL_ENGINE.generate_yield_string(
|
91 |
+
prompt=message,
|
92 |
+
temperature=temperature,
|
93 |
+
max_tokens=max_tokens,
|
94 |
+
stop_strings=stop_strings,
|
95 |
+
)):
|
96 |
+
if isinstance(outputs, tuple):
|
97 |
+
response, num_tokens = outputs
|
98 |
+
else:
|
99 |
+
response, num_tokens = outputs, -1
|
100 |
+
yield message + response, f"{num_tokens} tokens"
|
101 |
+
|
102 |
+
if response is not None:
|
103 |
+
yield message + response, f"{num_tokens} tokens"
|
104 |
+
|
105 |
+
|
106 |
+
@register_demo
|
107 |
+
class TextCompletionDemo(BaseDemo):
|
108 |
+
@property
|
109 |
+
def tab_name(self):
|
110 |
+
return "Text Completion"
|
111 |
+
|
112 |
+
def create_demo(
|
113 |
+
self,
|
114 |
+
title: str | None = None,
|
115 |
+
description: str | None = None,
|
116 |
+
**kwargs
|
117 |
+
) -> gr.Blocks:
|
118 |
+
system_prompt = kwargs.get("system_prompt", SYSTEM_PROMPT)
|
119 |
+
max_tokens = kwargs.get("max_tokens", MAX_TOKENS)
|
120 |
+
temperature = kwargs.get("temperature", TEMPERATURE)
|
121 |
+
model_name = kwargs.get("model_name", MODEL_NAME)
|
122 |
+
# frequence_penalty = FREQUENCE_PENALTY
|
123 |
+
# presence_penalty = PRESENCE_PENALTY
|
124 |
+
max_tokens = max_tokens // 2
|
125 |
+
|
126 |
+
description = description or f"""Put any context string (like few-shot prompts)"""
|
127 |
+
|
128 |
+
with gr.Blocks() as demo_text_completion:
|
129 |
+
if title:
|
130 |
+
gr.Markdown(title)
|
131 |
+
if description:
|
132 |
+
gr.Markdown(description)
|
133 |
+
with gr.Row():
|
134 |
+
txt = gr.Textbox(
|
135 |
+
scale=4,
|
136 |
+
lines=16,
|
137 |
+
show_label=False,
|
138 |
+
placeholder="Enter any free form text and submit",
|
139 |
+
container=False,
|
140 |
+
)
|
141 |
+
with gr.Row():
|
142 |
+
submit_button = gr.Button('Submit', variant='primary', scale=9)
|
143 |
+
stop_button = gr.Button('Stop', variant='stop', scale=9, visible=False)
|
144 |
+
num_tokens = Textbox(
|
145 |
+
container=False,
|
146 |
+
show_label=False,
|
147 |
+
label="num_tokens",
|
148 |
+
placeholder="0 tokens",
|
149 |
+
scale=1,
|
150 |
+
interactive=False,
|
151 |
+
min_width=10
|
152 |
+
)
|
153 |
+
with gr.Row():
|
154 |
+
temp_input = gr.Number(value=temperature, label='Temperature', info="Higher -> more random")
|
155 |
+
length_input = gr.Number(value=max_tokens, label='Max tokens', info='Increase if want more generation')
|
156 |
+
stop_strings = gr.Textbox(value="<s>,</s>,<|im_start|>,<|im_end|>", label='Stop strings', info='Comma-separated string to stop generation only in FEW-SHOT mode', lines=1)
|
157 |
+
examples = gr.Examples(
|
158 |
+
examples=[
|
159 |
+
["The following is the recite the declaration of independence:",]
|
160 |
+
],
|
161 |
+
inputs=[txt, temp_input, length_input, stop_strings],
|
162 |
+
# outputs=[txt]
|
163 |
+
)
|
164 |
+
# ! Handle stop button
|
165 |
+
submit_trigger = submit_button.click
|
166 |
+
submit_event = submit_button.click(
|
167 |
+
# submit_trigger,
|
168 |
+
generate_text_completion_stream_engine,
|
169 |
+
[txt, temp_input, length_input, stop_strings],
|
170 |
+
[txt, num_tokens],
|
171 |
+
# api_name=False,
|
172 |
+
# queue=False,
|
173 |
+
)
|
174 |
+
|
175 |
+
submit_trigger(
|
176 |
+
lambda: (
|
177 |
+
Button(visible=False), Button(visible=True),
|
178 |
+
),
|
179 |
+
None,
|
180 |
+
[submit_button, stop_button],
|
181 |
+
api_name=False,
|
182 |
+
queue=False,
|
183 |
+
)
|
184 |
+
submit_event.then(
|
185 |
+
lambda: (Button(visible=True), Button(visible=False)),
|
186 |
+
None,
|
187 |
+
[submit_button, stop_button],
|
188 |
+
api_name=False,
|
189 |
+
queue=False,
|
190 |
+
)
|
191 |
+
stop_button.click(
|
192 |
+
None,
|
193 |
+
None,
|
194 |
+
None,
|
195 |
+
cancels=submit_event,
|
196 |
+
api_name=False,
|
197 |
+
)
|
198 |
+
|
199 |
+
return demo_text_completion
|
multipurpose_chatbot/engines/.DS_Store
ADDED
Binary file (6.15 kB). View file
|
|
multipurpose_chatbot/engines/__init__.py
ADDED
@@ -0,0 +1,53 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
|
2 |
+
from .base_engine import BaseEngine
|
3 |
+
|
4 |
+
BACKENDS = [
|
5 |
+
"mlx",
|
6 |
+
"vllm",
|
7 |
+
"transformers",
|
8 |
+
"llama_cpp",
|
9 |
+
# "llava_llama_cpp",
|
10 |
+
"debug",
|
11 |
+
"sealmmm_transformers",
|
12 |
+
]
|
13 |
+
|
14 |
+
ENGINE_LOADED = False
|
15 |
+
|
16 |
+
def load_multipurpose_chatbot_engine(backend: str):
|
17 |
+
# ! lazy import other engines
|
18 |
+
global ENGINE_LOADED
|
19 |
+
assert backend in BACKENDS, f'{backend} not in {BACKENDS}'
|
20 |
+
if ENGINE_LOADED:
|
21 |
+
raise RuntimeError(f'{ENGINE_LOADED=} this means load_multipurpose_chatbot_engine has already been called! Check your codes.')
|
22 |
+
print(f'Load model from {backend}')
|
23 |
+
if backend == "mlx":
|
24 |
+
from .mlx_engine import MlxEngine
|
25 |
+
model_engine = MlxEngine()
|
26 |
+
elif backend == 'vllm':
|
27 |
+
from .vllm_engine import VllmEngine
|
28 |
+
model_engine = VllmEngine()
|
29 |
+
elif backend == 'transformers':
|
30 |
+
from .transformers_engine import TransformersEngine
|
31 |
+
model_engine = TransformersEngine()
|
32 |
+
elif backend == 'llama_cpp':
|
33 |
+
from .llama_cpp_engine import LlamaCppEngine
|
34 |
+
model_engine = LlamaCppEngine()
|
35 |
+
# ! llava_llama_cpp currently not done due to bugs
|
36 |
+
# elif backend == 'llava_llama_cpp':
|
37 |
+
# from .llava_llama_cpp_engine import LlavaLlamaCppEngine
|
38 |
+
# model_engine = LlavaLlamaCppEngine()
|
39 |
+
elif backend == 'debug':
|
40 |
+
from .debug_engine import DebugEngine
|
41 |
+
model_engine = DebugEngine()
|
42 |
+
elif backend == 'sealmmm_transformers':
|
43 |
+
from .sealmmm_engine import SeaLMMMv0Engine
|
44 |
+
model_engine = SeaLMMMv0Engine()
|
45 |
+
else:
|
46 |
+
raise ValueError(f'backend invalid: {BACKENDS} vs {backend}')
|
47 |
+
|
48 |
+
model_engine.load_model()
|
49 |
+
ENGINE_LOADED = True
|
50 |
+
return model_engine
|
51 |
+
# ! add more llama.cpp engine here.
|
52 |
+
|
53 |
+
|
multipurpose_chatbot/engines/base_engine.py
ADDED
@@ -0,0 +1,42 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import numpy as np
|
3 |
+
from huggingface_hub import snapshot_download
|
4 |
+
# ! Avoid importing transformers
|
5 |
+
# from transformers import AutoConfig, AutoTokenizer, PreTrainedTokenizer
|
6 |
+
from typing import Any, Callable, Dict, Generator, List, Optional, Tuple, Union
|
7 |
+
import time
|
8 |
+
|
9 |
+
|
10 |
+
class BaseEngine(object):
|
11 |
+
def __init__(self, **kwargs) -> None:
|
12 |
+
pass
|
13 |
+
|
14 |
+
@property
|
15 |
+
def max_position_embeddings(self) -> int:
|
16 |
+
return 10000
|
17 |
+
|
18 |
+
@property
|
19 |
+
def tokenizer(self):
|
20 |
+
raise NotImplementedError
|
21 |
+
|
22 |
+
def load_model(self, ):
|
23 |
+
raise NotImplementedError
|
24 |
+
|
25 |
+
def apply_chat_template(self, conversations, add_generation_prompt: bool, add_special_tokens=False, **kwargs) -> str:
|
26 |
+
"""
|
27 |
+
return string convo, add_special_tokens should be added later
|
28 |
+
"""
|
29 |
+
bos_token = self.tokenizer.bos_token
|
30 |
+
eos_token = self.tokenizer.eos_token
|
31 |
+
if not add_special_tokens:
|
32 |
+
# prevent bos being added to string
|
33 |
+
self.tokenizer.bos_token = ""
|
34 |
+
self.tokenizer.eos_token = ""
|
35 |
+
full_prompt = self.tokenizer.apply_chat_template(
|
36 |
+
conversations, add_generation_prompt=add_generation_prompt,
|
37 |
+
tokenize=False,
|
38 |
+
)
|
39 |
+
self.tokenizer.bos_token = bos_token
|
40 |
+
self.tokenizer.eos_token = eos_token
|
41 |
+
return full_prompt
|
42 |
+
|
multipurpose_chatbot/engines/debug_engine.py
ADDED
@@ -0,0 +1,49 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import numpy as np
|
3 |
+
from huggingface_hub import snapshot_download
|
4 |
+
from transformers import AutoConfig, AutoTokenizer, PreTrainedTokenizer
|
5 |
+
from typing import Any, Callable, Dict, Generator, List, Optional, Tuple, Union
|
6 |
+
import time
|
7 |
+
|
8 |
+
from .base_engine import BaseEngine
|
9 |
+
|
10 |
+
from ..configs import (
|
11 |
+
MODEL_PATH,
|
12 |
+
)
|
13 |
+
|
14 |
+
FAKE_MODEL_PATH = os.environ.get("FAKE_MODEL_PATH", MODEL_PATH)
|
15 |
+
FAKE_RESPONSE = "Wow that's very very cool, please try again."
|
16 |
+
|
17 |
+
|
18 |
+
class DebugEngine(BaseEngine):
|
19 |
+
"""
|
20 |
+
It will always yield FAKE_RESPONSE
|
21 |
+
"""
|
22 |
+
|
23 |
+
def __init__(self, **kwargs) -> None:
|
24 |
+
super().__init__(**kwargs)
|
25 |
+
self._model = None
|
26 |
+
self._tokenizer = None
|
27 |
+
|
28 |
+
@property
|
29 |
+
def tokenizer(self) -> PreTrainedTokenizer:
|
30 |
+
if self._tokenizer is None:
|
31 |
+
self._tokenizer = AutoTokenizer.from_pretrained(FAKE_MODEL_PATH, trust_remote_code=True)
|
32 |
+
return self._tokenizer
|
33 |
+
|
34 |
+
def load_model(self):
|
35 |
+
print(f"Load fake model with tokenizer: {self.tokenizer}")
|
36 |
+
|
37 |
+
def generate_yield_string(self, prompt, temperature, max_tokens, stop_strings: Optional[Tuple[str]] = None, **kwargs):
|
38 |
+
|
39 |
+
num_tokens = len(self.tokenizer.encode(prompt))
|
40 |
+
response = FAKE_RESPONSE
|
41 |
+
for i in range(len(response)):
|
42 |
+
time.sleep(0.01)
|
43 |
+
yield response[:i], num_tokens
|
44 |
+
|
45 |
+
num_tokens = len(self.tokenizer.encode(prompt + response))
|
46 |
+
yield response, num_tokens
|
47 |
+
|
48 |
+
def batch_generate(self, prompts, temperature, max_tokens, stop_strings: Optional[Tuple[str]] = None, **kwargs):
|
49 |
+
return [p + " -- Test" for p in prompts]
|
multipurpose_chatbot/engines/llama_cpp_engine.py
ADDED
@@ -0,0 +1,131 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import numpy as np
|
3 |
+
import argparse
|
4 |
+
import gradio as gr
|
5 |
+
from typing import Any, Iterator
|
6 |
+
from typing import Iterator, List, Optional, Tuple
|
7 |
+
import filelock
|
8 |
+
import glob
|
9 |
+
import json
|
10 |
+
import time
|
11 |
+
from gradio.routes import Request
|
12 |
+
from gradio.utils import SyncToAsyncIterator, async_iteration
|
13 |
+
from gradio.helpers import special_args
|
14 |
+
import anyio
|
15 |
+
from typing import AsyncGenerator, Callable, Literal, Union, cast
|
16 |
+
|
17 |
+
from gradio_client.documentation import document, set_documentation_group
|
18 |
+
|
19 |
+
from typing import List, Optional, Union, Dict, Tuple
|
20 |
+
from tqdm.auto import tqdm
|
21 |
+
from huggingface_hub import snapshot_download
|
22 |
+
import types
|
23 |
+
|
24 |
+
from gradio.components import Button
|
25 |
+
from gradio.events import Dependency, EventListenerMethod
|
26 |
+
|
27 |
+
import types
|
28 |
+
import sys
|
29 |
+
|
30 |
+
from .base_engine import BaseEngine
|
31 |
+
|
32 |
+
# ! Remember to use static cache
|
33 |
+
|
34 |
+
from ..configs import (
|
35 |
+
MODEL_PATH,
|
36 |
+
DEFAULT_CHAT_TEMPLATE,
|
37 |
+
N_CTX,
|
38 |
+
N_GPU_LAYERS,
|
39 |
+
)
|
40 |
+
|
41 |
+
|
42 |
+
|
43 |
+
def encode_tokenize(self, prompt: str, **kwargs):
|
44 |
+
"""Mimic behavior of transformers tokenizer"""
|
45 |
+
prompt_tokens: List[int] = (
|
46 |
+
(
|
47 |
+
self.tokenize(prompt.encode("utf-8"), special=True)
|
48 |
+
if prompt != ""
|
49 |
+
else [self.token_bos()]
|
50 |
+
)
|
51 |
+
if isinstance(prompt, str)
|
52 |
+
else prompt
|
53 |
+
)
|
54 |
+
return prompt_tokens
|
55 |
+
|
56 |
+
|
57 |
+
conversations = [
|
58 |
+
{"role": "system", "content": "You are good."},
|
59 |
+
{"role": "user", "content": "Hello."},
|
60 |
+
{"role": "assistant", "content": "Hi."},
|
61 |
+
]
|
62 |
+
|
63 |
+
|
64 |
+
class LlamaCppEngine(BaseEngine):
|
65 |
+
"""
|
66 |
+
need to create an engine.tokenizer.encode(text) method
|
67 |
+
"""
|
68 |
+
@property
|
69 |
+
def max_position_embeddings(self) -> int:
|
70 |
+
# raise ValueError
|
71 |
+
return self._model.context_params.n_ctx
|
72 |
+
|
73 |
+
def apply_chat_template(self, conversations, add_generation_prompt: bool, add_special_tokens=False, **kwargs) -> str:
|
74 |
+
"""
|
75 |
+
return string convo, add_special_tokens should be added later
|
76 |
+
remember to remove <s> if any,
|
77 |
+
"""
|
78 |
+
from llama_cpp.llama_chat_format import Jinja2ChatFormatter
|
79 |
+
|
80 |
+
formatter = Jinja2ChatFormatter(
|
81 |
+
template=self._model.metadata['tokenizer.chat_template'],
|
82 |
+
# bos_token=self._model._model.token_get_text(self._model.token_bos()),
|
83 |
+
bos_token="",
|
84 |
+
eos_token=self._model._model.token_get_text(self._model.token_eos()),
|
85 |
+
add_generation_prompt=add_generation_prompt,
|
86 |
+
)
|
87 |
+
|
88 |
+
full_prompt = formatter(messages=conversations).prompt
|
89 |
+
# ! it may has bos
|
90 |
+
return full_prompt
|
91 |
+
|
92 |
+
@property
|
93 |
+
def tokenizer(self):
|
94 |
+
return self._model
|
95 |
+
|
96 |
+
def load_model(self):
|
97 |
+
# from transformers import AutoTokenizer, AutoModelForCausalLM
|
98 |
+
|
99 |
+
from llama_cpp import Llama
|
100 |
+
self.model_path = MODEL_PATH
|
101 |
+
self._model = Llama(
|
102 |
+
model_path=self.model_path,
|
103 |
+
n_gpu_layers=N_GPU_LAYERS, # Uncomment to use GPU acceleration
|
104 |
+
# seed=1337, # Uncomment to set a specific seed
|
105 |
+
n_ctx=N_CTX, # Uncomment to increase the context window
|
106 |
+
)
|
107 |
+
self._tokenizer = self._model
|
108 |
+
self._model.encode = types.MethodType(encode_tokenize, self._model)
|
109 |
+
print(f'Load model: {self.model_path=} | {N_GPU_LAYERS=} | {N_CTX=}')
|
110 |
+
|
111 |
+
def generate_yield_string(self, prompt, temperature, max_tokens, stop_strings: Optional[Tuple[str]] = None, **kwargs):
|
112 |
+
stop_strings = list(stop_strings) if stop_strings is not None else []
|
113 |
+
stop_strings = list(set(stop_strings + ["</s>", "<|im_end|>"]))
|
114 |
+
generator = self._model(
|
115 |
+
prompt,
|
116 |
+
max_tokens=max_tokens, # Generate up to 32 tokens, set to None to generate up to the end of the context window
|
117 |
+
temperature=temperature,
|
118 |
+
stop=stop_strings, # Stop generating just before the model would generate a new question
|
119 |
+
stream=True,
|
120 |
+
)
|
121 |
+
response = ""
|
122 |
+
num_tokens = len(self.tokenizer.encode(prompt))
|
123 |
+
for g in generator:
|
124 |
+
response += g['choices'][0]['text']
|
125 |
+
yield response, num_tokens
|
126 |
+
|
127 |
+
if response is not None and len(response) > 0:
|
128 |
+
num_tokens = len(self.tokenizer.encode(prompt + response))
|
129 |
+
yield response, num_tokens
|
130 |
+
|
131 |
+
|
multipurpose_chatbot/engines/llava_llama_cpp_engine.py
ADDED
@@ -0,0 +1,280 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import numpy as np
|
3 |
+
import argparse
|
4 |
+
import gradio as gr
|
5 |
+
from typing import Any, Iterator
|
6 |
+
from typing import Iterator, List, Optional, Tuple
|
7 |
+
import filelock
|
8 |
+
import glob
|
9 |
+
import json
|
10 |
+
import time
|
11 |
+
from gradio.routes import Request
|
12 |
+
from gradio.utils import SyncToAsyncIterator, async_iteration
|
13 |
+
from gradio.helpers import special_args
|
14 |
+
import anyio
|
15 |
+
from typing import AsyncGenerator, Callable, Literal, Union, cast
|
16 |
+
|
17 |
+
from gradio_client.documentation import document, set_documentation_group
|
18 |
+
|
19 |
+
from typing import List, Optional, Union, Dict, Tuple
|
20 |
+
from tqdm.auto import tqdm
|
21 |
+
from huggingface_hub import snapshot_download
|
22 |
+
import types
|
23 |
+
|
24 |
+
from gradio.components import Button
|
25 |
+
from gradio.events import Dependency, EventListenerMethod
|
26 |
+
|
27 |
+
import types
|
28 |
+
import sys
|
29 |
+
|
30 |
+
from .base_engine import BaseEngine
|
31 |
+
|
32 |
+
# ! Remember to use static cache
|
33 |
+
|
34 |
+
from ..configs import (
|
35 |
+
MODEL_PATH,
|
36 |
+
DEFAULT_CHAT_TEMPLATE,
|
37 |
+
N_CTX,
|
38 |
+
N_GPU_LAYERS,
|
39 |
+
IMAGE_TOKEN,
|
40 |
+
IMAGE_TOKEN_INTERACTIVE,
|
41 |
+
IMAGE_TOKEN_LENGTH,
|
42 |
+
MAX_PACHES,
|
43 |
+
)
|
44 |
+
|
45 |
+
from .llama_cpp_engine import (
|
46 |
+
encode_tokenize,
|
47 |
+
LlamaCppEngine,
|
48 |
+
)
|
49 |
+
|
50 |
+
|
51 |
+
|
52 |
+
# resource: https://llama-cpp-python.readthedocs.io/en/latest/#multi-modal-models
|
53 |
+
|
54 |
+
import base64
|
55 |
+
|
56 |
+
def image_to_base64_data_uri(file_path):
|
57 |
+
with open(file_path, "rb") as img_file:
|
58 |
+
base64_data = base64.b64encode(img_file.read()).decode('utf-8')
|
59 |
+
return f"data:image/png;base64,{base64_data}"
|
60 |
+
|
61 |
+
|
62 |
+
# file_path = 'file_path.png'
|
63 |
+
# data_uri = image_to_base64_data_uri(file_path)
|
64 |
+
|
65 |
+
# data_uri = image_to_base64_data_uri(file_path)
|
66 |
+
|
67 |
+
# messages = [
|
68 |
+
# {"role": "system", "content": "You are an assistant who perfectly describes images."},
|
69 |
+
# {
|
70 |
+
# "role": "user",
|
71 |
+
# "content": [
|
72 |
+
# {"type": "image_url", "image_url": {"url": data_uri }},
|
73 |
+
# {"type" : "text", "text": "Describe this image in detail please."}
|
74 |
+
# ]
|
75 |
+
# }
|
76 |
+
# ]
|
77 |
+
|
78 |
+
|
79 |
+
def llava_15_chat_handler_call(
|
80 |
+
self,
|
81 |
+
*,
|
82 |
+
llama: Any,
|
83 |
+
# messages: List[Any],
|
84 |
+
prompt: Union[str, List[int]],
|
85 |
+
image_data_uris: Optional[List[Any]] = None,
|
86 |
+
image_token: str = None,
|
87 |
+
functions: Optional[List[Any]] = None,
|
88 |
+
function_call: Optional[Any] = None,
|
89 |
+
tools: Optional[List[Any]] = None,
|
90 |
+
tool_choice: Optional[Any] = None,
|
91 |
+
temperature: float = 0.2,
|
92 |
+
top_p: float = 0.95,
|
93 |
+
top_k: int = 40,
|
94 |
+
min_p: float = 0.05,
|
95 |
+
typical_p: float = 1.0,
|
96 |
+
stream: bool = False,
|
97 |
+
stop: Optional[Union[str, List[str]]] = [],
|
98 |
+
response_format: Optional[
|
99 |
+
Any
|
100 |
+
] = None,
|
101 |
+
max_tokens: Optional[int] = None,
|
102 |
+
presence_penalty: float = 0.0,
|
103 |
+
frequency_penalty: float = 0.0,
|
104 |
+
repeat_penalty: float = 1.1,
|
105 |
+
tfs_z: float = 1.0,
|
106 |
+
mirostat_mode: int = 0,
|
107 |
+
mirostat_tau: float = 5.0,
|
108 |
+
mirostat_eta: float = 0.1,
|
109 |
+
model: Optional[str] = None,
|
110 |
+
logits_processor: Optional[Any] = None,
|
111 |
+
grammar: Optional[Any] = None,
|
112 |
+
**kwargs, # type: ignore
|
113 |
+
):
|
114 |
+
from llama_cpp.llama_chat_format import (
|
115 |
+
ctypes,
|
116 |
+
suppress_stdout_stderr,
|
117 |
+
)
|
118 |
+
assert (
|
119 |
+
llama.context_params.logits_all is True
|
120 |
+
) # BUG: logits_all=True is required for llava
|
121 |
+
assert self.clip_ctx is not None
|
122 |
+
# ! split prompt into different parts
|
123 |
+
assert image_token is not None
|
124 |
+
prompt_parts = prompt.split(image_token)
|
125 |
+
# assert len(prompt_parts)
|
126 |
+
assert len(prompt_parts) == len(image_data_uris) + 1, f'invalid {len(prompt_parts)=} != {len(image_data_uris)=}'
|
127 |
+
llama.reset()
|
128 |
+
prefix = prompt_parts[0]
|
129 |
+
remaining_texts = prompt_parts[1:]
|
130 |
+
llama.reset()
|
131 |
+
llama.eval(llama.tokenize(prefix.encode("utf8"), add_bos=True))
|
132 |
+
for index, (image_uri, prompt_p) in enumerate(zip(image_data_uris, remaining_texts)):
|
133 |
+
image_bytes = self.load_image(image_uri)
|
134 |
+
import array
|
135 |
+
data_array = array.array("B", image_bytes)
|
136 |
+
c_ubyte_ptr = (
|
137 |
+
ctypes.c_ubyte * len(data_array)
|
138 |
+
).from_buffer(data_array)
|
139 |
+
with suppress_stdout_stderr(disable=self.verbose):
|
140 |
+
embed = (
|
141 |
+
self._llava_cpp.llava_image_embed_make_with_bytes(
|
142 |
+
self.clip_ctx,
|
143 |
+
llama.context_params.n_threads,
|
144 |
+
c_ubyte_ptr,
|
145 |
+
len(image_bytes),
|
146 |
+
)
|
147 |
+
)
|
148 |
+
try:
|
149 |
+
n_past = ctypes.c_int(llama.n_tokens)
|
150 |
+
n_past_p = ctypes.pointer(n_past)
|
151 |
+
with suppress_stdout_stderr(disable=self.verbose):
|
152 |
+
self._llava_cpp.llava_eval_image_embed(
|
153 |
+
llama.ctx,
|
154 |
+
embed,
|
155 |
+
llama.n_batch,
|
156 |
+
n_past_p,
|
157 |
+
)
|
158 |
+
assert llama.n_ctx() >= n_past.value
|
159 |
+
llama.n_tokens = n_past.value
|
160 |
+
finally:
|
161 |
+
with suppress_stdout_stderr(disable=self.verbose):
|
162 |
+
self._llava_cpp.llava_image_embed_free(embed)
|
163 |
+
|
164 |
+
llama.eval(llama.tokenize(prompt_p.encode("utf8"), add_bos=False))
|
165 |
+
assert llama.n_ctx() >= llama.n_tokens
|
166 |
+
|
167 |
+
prompt = llama.input_ids[: llama.n_tokens].tolist()
|
168 |
+
# from llava-1.5
|
169 |
+
return llama.create_completion(
|
170 |
+
prompt=prompt,
|
171 |
+
temperature=temperature,
|
172 |
+
top_p=top_p,
|
173 |
+
top_k=top_k,
|
174 |
+
min_p=min_p,
|
175 |
+
typical_p=typical_p,
|
176 |
+
stream=stream,
|
177 |
+
stop=stop,
|
178 |
+
max_tokens=max_tokens,
|
179 |
+
presence_penalty=presence_penalty,
|
180 |
+
frequency_penalty=frequency_penalty,
|
181 |
+
repeat_penalty=repeat_penalty,
|
182 |
+
tfs_z=tfs_z,
|
183 |
+
mirostat_mode=mirostat_mode,
|
184 |
+
mirostat_tau=mirostat_tau,
|
185 |
+
mirostat_eta=mirostat_eta,
|
186 |
+
model=model,
|
187 |
+
logits_processor=logits_processor,
|
188 |
+
grammar=grammar,
|
189 |
+
)
|
190 |
+
|
191 |
+
|
192 |
+
|
193 |
+
class LlavaLlamaCppEngine(LlamaCppEngine):
|
194 |
+
"""
|
195 |
+
Still in development, expect BUGS
|
196 |
+
|
197 |
+
ERROR: could not know why
|
198 |
+
objc[61055]: Class GGMLMetalClass is implemented in both miniconda3/envs/native/lib/python3.12/site-packages/llama_cpp/libllama.dylib (0x12cb40290) and miniconda3/envs/native/lib/python3.12/site-packages/llama_cpp/libllava.dylib (0x12d9c8290). One of the two will be used. Which one is undefined.
|
199 |
+
|
200 |
+
"""
|
201 |
+
@property
|
202 |
+
def image_token(self):
|
203 |
+
return IMAGE_TOKEN
|
204 |
+
|
205 |
+
def get_multimodal_tokens(self, full_prompt, image_paths=None):
|
206 |
+
num_tokens = len(self.tokenizer.encode(full_prompt))
|
207 |
+
for image_path in image_paths:
|
208 |
+
num_tokens += IMAGE_TOKEN_LENGTH * MAX_PACHES
|
209 |
+
return num_tokens
|
210 |
+
|
211 |
+
def load_model(self):
|
212 |
+
# from transformers import AutoTokenizer, AutoModelForCausalLM
|
213 |
+
from llama_cpp import Llama
|
214 |
+
from llama_cpp.llama_chat_format import Llava15ChatHandler
|
215 |
+
model_dir = os.path.dirname(MODEL_PATH)
|
216 |
+
self.chat_handler = Llava15ChatHandler(clip_model_path=os.path.join(model_dir, "mmproj.bin"))
|
217 |
+
|
218 |
+
self.chat_handler.__call__ = types.MethodType(llava_15_chat_handler_call, self.chat_handler)
|
219 |
+
|
220 |
+
self.model_path = MODEL_PATH
|
221 |
+
self._model = Llama(
|
222 |
+
model_path=self.model_path,
|
223 |
+
n_gpu_layers=N_GPU_LAYERS, # Uncomment to use GPU acceleration
|
224 |
+
# seed=1337, # Uncomment to set a specific seed
|
225 |
+
chat_handler=self.chat_handler,
|
226 |
+
n_ctx=N_CTX, # Uncomment to increase the context window
|
227 |
+
logits_all=True, # needed to make llava work
|
228 |
+
)
|
229 |
+
self._tokenizer = self._model
|
230 |
+
self._model.encode = types.MethodType(encode_tokenize, self._model)
|
231 |
+
print(f'Load model: {self.model_path=} | {N_GPU_LAYERS=} | {N_CTX=}')
|
232 |
+
|
233 |
+
def generate_yield_string(self, prompt, temperature, max_tokens, stop_strings: Optional[Tuple[str]] = None, **kwargs):
|
234 |
+
image_paths = kwargs.get("image_paths", [])
|
235 |
+
|
236 |
+
image_data_uris = [
|
237 |
+
image_to_base64_data_uri(ip)
|
238 |
+
for ip in image_paths
|
239 |
+
]
|
240 |
+
|
241 |
+
stop_strings = list(stop_strings) if stop_strings is not None else []
|
242 |
+
stop_strings = list(set(stop_strings + ["</s>", "<|im_end|>"]))
|
243 |
+
# generator = self._model(
|
244 |
+
generator = self.chat_handler(
|
245 |
+
prompt=prompt,
|
246 |
+
image_data_uris=image_data_uris,
|
247 |
+
image_token=self.image_token,
|
248 |
+
max_tokens=max_tokens, # Generate up to 32 tokens, set to None to generate up to the end of the context window
|
249 |
+
temperature=temperature,
|
250 |
+
stop=stop_strings, # Stop generating just before the model would generate a new question
|
251 |
+
stream=True,
|
252 |
+
)
|
253 |
+
response = ""
|
254 |
+
num_tokens = len(self.tokenizer.encode(prompt))
|
255 |
+
for g in generator:
|
256 |
+
response += g['choices'][0]['text']
|
257 |
+
yield response, num_tokens
|
258 |
+
|
259 |
+
if response is not None and len(response) > 0:
|
260 |
+
num_tokens = len(self.tokenizer.encode(prompt + response))
|
261 |
+
yield response, num_tokens
|
262 |
+
|
263 |
+
|
264 |
+
"""
|
265 |
+
|
266 |
+
export MODEL_PATH
|
267 |
+
BACKEND=llama_cpp
|
268 |
+
MODEL_PATH=/Users/nguyenxuanphi/Desktop/projects/cache/seallms/SeaLLMs/SeaLLM-7B-v2-gguf/seallm-v2.chatml.Q4_K_M.gguf
|
269 |
+
N_CTX=4096
|
270 |
+
python app.py
|
271 |
+
|
272 |
+
|
273 |
+
export BACKEND=llava_llama_cpp
|
274 |
+
export MODEL_PATH=/Users/nguyenxuanphi/Desktop/projects/cache/llava/llava-1.5/ggml-model-q4_k.gguf
|
275 |
+
export N_CTX=4096
|
276 |
+
export IMAGE_TOKEN="<image>"
|
277 |
+
python app.py
|
278 |
+
|
279 |
+
|
280 |
+
"""
|
multipurpose_chatbot/engines/mlx_engine.py
ADDED
@@ -0,0 +1,202 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import numpy as np
|
3 |
+
import mlx.core as mx
|
4 |
+
import mlx.nn as nn
|
5 |
+
from huggingface_hub import snapshot_download
|
6 |
+
from transformers import AutoConfig, AutoTokenizer, PreTrainedTokenizer
|
7 |
+
from typing import Any, Callable, Dict, Generator, List, Optional, Tuple, Union
|
8 |
+
import time
|
9 |
+
from mlx_lm import load, generate
|
10 |
+
from mlx_lm.utils import generate_step
|
11 |
+
|
12 |
+
from .base_engine import BaseEngine
|
13 |
+
|
14 |
+
from ..configs import (
|
15 |
+
MODEL_PATH,
|
16 |
+
)
|
17 |
+
|
18 |
+
def generate_string(
|
19 |
+
model: nn.Module,
|
20 |
+
tokenizer: PreTrainedTokenizer,
|
21 |
+
prompt: str,
|
22 |
+
temp: float = 0.0,
|
23 |
+
max_tokens: int = 100,
|
24 |
+
verbose: bool = False,
|
25 |
+
formatter: Callable = None,
|
26 |
+
repetition_penalty: Optional[float] = None,
|
27 |
+
repetition_context_size: Optional[int] = None,
|
28 |
+
stop_strings: Optional[Tuple[str]] = None
|
29 |
+
):
|
30 |
+
prompt_tokens = mx.array(tokenizer.encode(prompt))
|
31 |
+
stop_strings = stop_strings if stop_strings is None or isinstance(stop_strings, tuple) else tuple(stop_strings)
|
32 |
+
assert stop_strings is None or isinstance(stop_strings, tuple), f'invalid {stop_strings}'
|
33 |
+
|
34 |
+
tic = time.perf_counter()
|
35 |
+
tokens = []
|
36 |
+
skip = 0
|
37 |
+
REPLACEMENT_CHAR = "\ufffd"
|
38 |
+
|
39 |
+
for (token, prob), n in zip(
|
40 |
+
generate_step(
|
41 |
+
prompt_tokens,
|
42 |
+
model,
|
43 |
+
temp,
|
44 |
+
repetition_penalty,
|
45 |
+
repetition_context_size,
|
46 |
+
),
|
47 |
+
range(max_tokens),
|
48 |
+
):
|
49 |
+
if token == tokenizer.eos_token_id:
|
50 |
+
break
|
51 |
+
if n == 0:
|
52 |
+
prompt_time = time.perf_counter() - tic
|
53 |
+
tic = time.perf_counter()
|
54 |
+
tokens.append(token.item())
|
55 |
+
if stop_strings is not None:
|
56 |
+
token_string = tokenizer.decode(tokens).replace(REPLACEMENT_CHAR, "")
|
57 |
+
if token_string.strip().endswith(stop_strings):
|
58 |
+
break
|
59 |
+
token_string = tokenizer.decode(tokens).replace(REPLACEMENT_CHAR, "")
|
60 |
+
return token_string
|
61 |
+
|
62 |
+
|
63 |
+
|
64 |
+
def generate_yield_string(
|
65 |
+
model: nn.Module,
|
66 |
+
tokenizer: PreTrainedTokenizer,
|
67 |
+
prompt: str,
|
68 |
+
temp: float = 0.0,
|
69 |
+
max_tokens: int = 100,
|
70 |
+
verbose: bool = False,
|
71 |
+
formatter: Callable = None,
|
72 |
+
repetition_penalty: Optional[float] = None,
|
73 |
+
repetition_context_size: Optional[int] = None,
|
74 |
+
stop_strings: Optional[Tuple[str]] = None
|
75 |
+
):
|
76 |
+
"""
|
77 |
+
Generate text from the model.
|
78 |
+
Args:
|
79 |
+
model (nn.Module): The language model.
|
80 |
+
tokenizer (PreTrainedTokenizer): The tokenizer.
|
81 |
+
prompt (str): The string prompt.
|
82 |
+
temp (float): The temperature for sampling (default 0).
|
83 |
+
max_tokens (int): The maximum number of tokens (default 100).
|
84 |
+
verbose (bool): If ``True``, print tokens and timing information
|
85 |
+
(default ``False``).
|
86 |
+
formatter (Optional[Callable]): A function which takes a token and a
|
87 |
+
probability and displays it.
|
88 |
+
repetition_penalty (float, optional): The penalty factor for repeating tokens.
|
89 |
+
repetition_context_size (int, optional): The number of tokens to consider for repetition penalty.
|
90 |
+
"""
|
91 |
+
if verbose:
|
92 |
+
print("=" * 10)
|
93 |
+
print("Prompt:", prompt)
|
94 |
+
stop_strings = stop_strings if stop_strings is None or isinstance(stop_strings, tuple) else tuple(stop_strings)
|
95 |
+
assert stop_strings is None or isinstance(stop_strings, tuple), f'invalid {stop_strings}'
|
96 |
+
prompt_tokens = mx.array(tokenizer.encode(prompt))
|
97 |
+
tic = time.perf_counter()
|
98 |
+
tokens = []
|
99 |
+
skip = 0
|
100 |
+
REPLACEMENT_CHAR = "\ufffd"
|
101 |
+
for (token, prob), n in zip(
|
102 |
+
generate_step(
|
103 |
+
prompt_tokens,
|
104 |
+
model,
|
105 |
+
temp,
|
106 |
+
repetition_penalty,
|
107 |
+
repetition_context_size,
|
108 |
+
),
|
109 |
+
range(max_tokens),
|
110 |
+
):
|
111 |
+
if token == tokenizer.eos_token_id:
|
112 |
+
break
|
113 |
+
# if n == 0:
|
114 |
+
# prompt_time = time.perf_counter() - tic
|
115 |
+
# tic = time.perf_counter()
|
116 |
+
tokens.append(token.item())
|
117 |
+
# if verbose:
|
118 |
+
# s = tokenizer.decode(tokens)
|
119 |
+
# if formatter:
|
120 |
+
# formatter(s[skip:], prob.item())
|
121 |
+
# skip = len(s)
|
122 |
+
# elif REPLACEMENT_CHAR not in s:
|
123 |
+
# print(s[skip:], end="", flush=True)
|
124 |
+
# skip = len(s)
|
125 |
+
token_string = tokenizer.decode(tokens).replace(REPLACEMENT_CHAR, "")
|
126 |
+
yield token_string
|
127 |
+
if stop_strings is not None and token_string.strip().endswith(stop_strings):
|
128 |
+
break
|
129 |
+
|
130 |
+
# token_count = len(tokens)
|
131 |
+
# token_string = tokenizer.decode(tokens).replace(REPLACEMENT_CHAR, "")
|
132 |
+
|
133 |
+
# if verbose:
|
134 |
+
# print(token_string[skip:], flush=True)
|
135 |
+
# gen_time = time.perf_counter() - tic
|
136 |
+
# print("=" * 10)
|
137 |
+
# if token_count == 0:
|
138 |
+
# print("No tokens generated for this prompt")
|
139 |
+
# return
|
140 |
+
# prompt_tps = prompt_tokens.size / prompt_time
|
141 |
+
# gen_tps = (token_count - 1) / gen_time
|
142 |
+
# print(f"Prompt: {prompt_tps:.3f} tokens-per-sec")
|
143 |
+
# print(f"Generation: {gen_tps:.3f} tokens-per-sec")
|
144 |
+
|
145 |
+
# return token_string
|
146 |
+
|
147 |
+
|
148 |
+
class MlxEngine(BaseEngine):
|
149 |
+
|
150 |
+
def __init__(self, **kwargs) -> None:
|
151 |
+
super().__init__(**kwargs)
|
152 |
+
self._model = None
|
153 |
+
self._tokenizer = None
|
154 |
+
|
155 |
+
@property
|
156 |
+
def tokenizer(self) -> PreTrainedTokenizer:
|
157 |
+
return self._tokenizer
|
158 |
+
|
159 |
+
def load_model(self, ):
|
160 |
+
model_path = MODEL_PATH
|
161 |
+
self._model, self._tokenizer = load(model_path)
|
162 |
+
self.model_path = model_path
|
163 |
+
print(f'Load MLX model from {model_path}')
|
164 |
+
|
165 |
+
|
166 |
+
def generate_yield_string(self, prompt, temperature, max_tokens, stop_strings: Optional[Tuple[str]] = None, **kwargs):
|
167 |
+
num_tokens = len(self.tokenizer.encode(prompt))
|
168 |
+
response = None
|
169 |
+
for response in generate_yield_string(
|
170 |
+
self._model, self._tokenizer,
|
171 |
+
prompt, temp=temperature, max_tokens=max_tokens,
|
172 |
+
repetition_penalty=kwargs.get("repetition_penalty", None),
|
173 |
+
stop_strings=stop_strings,
|
174 |
+
):
|
175 |
+
yield response, num_tokens
|
176 |
+
if response is not None:
|
177 |
+
full_text = prompt + response
|
178 |
+
num_tokens = len(self.tokenizer.encode(full_text))
|
179 |
+
yield response, num_tokens
|
180 |
+
|
181 |
+
def batch_generate(self, prompts, temperature, max_tokens, stop_strings: Optional[Tuple[str]] = None, **kwargs):
|
182 |
+
"""
|
183 |
+
! MLX does not support
|
184 |
+
"""
|
185 |
+
responses = [
|
186 |
+
generate_string(
|
187 |
+
self._model, self._tokenizer,
|
188 |
+
s, temp=temperature, max_tokens=max_tokens,
|
189 |
+
repetition_penalty=kwargs.get("repetition_penalty", None),
|
190 |
+
stop_strings=stop_strings,
|
191 |
+
)
|
192 |
+
for s in prompts
|
193 |
+
]
|
194 |
+
return responses
|
195 |
+
|
196 |
+
|
197 |
+
|
198 |
+
|
199 |
+
|
200 |
+
|
201 |
+
|
202 |
+
|
multipurpose_chatbot/engines/modeling_sealmm.py
ADDED
@@ -0,0 +1,1091 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from contextlib import nullcontext
|
2 |
+
from dataclasses import dataclass
|
3 |
+
from typing import List, Optional, Tuple, Union
|
4 |
+
|
5 |
+
import torch
|
6 |
+
import torch.utils.checkpoint
|
7 |
+
from torch import nn
|
8 |
+
|
9 |
+
from transformers import PreTrainedModel
|
10 |
+
from transformers.activations import ACT2FN
|
11 |
+
from transformers.cache_utils import Cache
|
12 |
+
from transformers.modeling_outputs import ModelOutput
|
13 |
+
from transformers.models.clip.configuration_clip import CLIPConfig
|
14 |
+
from transformers.utils import (
|
15 |
+
add_start_docstrings,
|
16 |
+
add_start_docstrings_to_model_forward,
|
17 |
+
logging,
|
18 |
+
replace_return_docstrings,
|
19 |
+
)
|
20 |
+
from transformers import AutoModel, AutoModelForCausalLM
|
21 |
+
from transformers.models.llava.configuration_llava import LlavaConfig
|
22 |
+
|
23 |
+
from transformers.models.llava.modeling_llava import (
|
24 |
+
LlavaCausalLMOutputWithPast,
|
25 |
+
LlavaMultiModalProjector,
|
26 |
+
LlavaPreTrainedModel,
|
27 |
+
LLAVA_START_DOCSTRING,
|
28 |
+
LLAVA_INPUTS_DOCSTRING,
|
29 |
+
LlavaForConditionalGeneration,
|
30 |
+
)
|
31 |
+
|
32 |
+
from transformers.models.blip_2.configuration_blip_2 import (
|
33 |
+
Blip2Config,
|
34 |
+
Blip2QFormerConfig,
|
35 |
+
)
|
36 |
+
import os
|
37 |
+
from transformers.models.blip_2.modeling_blip_2 import (
|
38 |
+
Blip2Config,
|
39 |
+
Blip2QFormerModel,
|
40 |
+
Blip2PreTrainedModel,
|
41 |
+
BLIP_2_INPUTS_DOCSTRING,
|
42 |
+
)
|
43 |
+
|
44 |
+
from transformers.utils.import_utils import is_flash_attn_greater_or_equal_2_10
|
45 |
+
|
46 |
+
# from .configuration_sealmm import SeaLMMConfig
|
47 |
+
|
48 |
+
logger = logging.get_logger(__name__)
|
49 |
+
|
50 |
+
# _CONFIG_FOR_DOC = "LlavaConfig"
|
51 |
+
_CONFIG_FOR_DOC = "SeaLMMConfig"
|
52 |
+
|
53 |
+
|
54 |
+
class SeaLMMConfig(LlavaConfig):
|
55 |
+
def __init__(self, *args, **kwargs):
|
56 |
+
self.projector_num_layers = kwargs.get("projector_num_layers", 1)
|
57 |
+
super().__init__(*args, **kwargs)
|
58 |
+
|
59 |
+
"""
|
60 |
+
Llava
|
61 |
+
|
62 |
+
vision_config.num_hidden_layers = vision_config.num_hidden_layers + config.vision_feature_layer + 1
|
63 |
+
# "num_hidden_layers": 24,
|
64 |
+
|
65 |
+
"""
|
66 |
+
|
67 |
+
IMAGE_TOKEN = "<|image|>"
|
68 |
+
DEBUG = bool(int(os.environ.get("DEBUG", "0")))
|
69 |
+
|
70 |
+
|
71 |
+
def by_sample_merge_input_ids_with_image_features(
|
72 |
+
self, image_features, inputs_embeds, input_ids, attention_mask=None, position_ids=None
|
73 |
+
):
|
74 |
+
"""
|
75 |
+
input_ids: [tlen]
|
76 |
+
input_embeds: [tlen, dt]
|
77 |
+
img_embeds: [ilen, ifeat, di]
|
78 |
+
|
79 |
+
e.g:
|
80 |
+
input_ids: [
|
81 |
+
a b c d e f X g h i j k X l m
|
82 |
+
]
|
83 |
+
img_embeds: [3, ifeat, id] # img_embeds has padding
|
84 |
+
"""
|
85 |
+
num_images, num_image_patches, embed_dim = image_features.shape
|
86 |
+
sequence_length = input_ids.size(0)
|
87 |
+
left_padding = not torch.sum(input_ids[:, -1] == torch.tensor(self.pad_token_id))
|
88 |
+
assert not left_padding, f'should only use right padding'
|
89 |
+
# 1. Create a mask to know where special image tokens are
|
90 |
+
special_image_token_mask = input_ids == self.config.image_token_index
|
91 |
+
num_special_image_tokens = torch.sum(special_image_token_mask, dim=-1)
|
92 |
+
# Compute the maximum embed dimension
|
93 |
+
max_embed_dim = (num_special_image_tokens.max() * (num_image_patches - 1)) + sequence_length
|
94 |
+
|
95 |
+
from transformers.models.clip.configuration_clip import CLIPConfig, CLIPTextConfig, CLIPVisionConfig
|
96 |
+
from transformers.models.clip.modeling_clip import (
|
97 |
+
contrastive_loss,
|
98 |
+
clip_loss,
|
99 |
+
CLIPVisionModelOutput,
|
100 |
+
CLIPTextModelOutput,
|
101 |
+
CLIPOutput,
|
102 |
+
CLIPTextEmbeddings,
|
103 |
+
CLIPVisionEmbeddings,
|
104 |
+
CLIPAttention,
|
105 |
+
CLIPMLP,
|
106 |
+
CLIPEncoderLayer,
|
107 |
+
CLIPPreTrainedModel,
|
108 |
+
CLIPTextTransformer,
|
109 |
+
CLIPTextModel,
|
110 |
+
CLIPVisionTransformer,
|
111 |
+
CLIPVisionModel,
|
112 |
+
CLIPModel,
|
113 |
+
CLIPEncoder,
|
114 |
+
CLIPTextModelWithProjection,
|
115 |
+
CLIPVisionModelWithProjection,
|
116 |
+
CLIP_START_DOCSTRING,
|
117 |
+
CLIP_TEXT_INPUTS_DOCSTRING,
|
118 |
+
CLIP_VISION_INPUTS_DOCSTRING,
|
119 |
+
CLIP_INPUTS_DOCSTRING,
|
120 |
+
)
|
121 |
+
from transformers.modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling
|
122 |
+
|
123 |
+
|
124 |
+
|
125 |
+
# Copied from transformers.models.llama.modeling_llama._get_unpad_data
|
126 |
+
def _get_unpad_data(attention_mask):
|
127 |
+
import torch.nn.functional as F
|
128 |
+
seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
|
129 |
+
indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
|
130 |
+
max_seqlen_in_batch = seqlens_in_batch.max().item()
|
131 |
+
cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.torch.int32), (1, 0))
|
132 |
+
return (
|
133 |
+
indices,
|
134 |
+
cu_seqlens,
|
135 |
+
max_seqlen_in_batch,
|
136 |
+
)
|
137 |
+
|
138 |
+
class CLIPFlashAttention2(CLIPAttention):
|
139 |
+
"""
|
140 |
+
CLIP flash attention module. This module inherits from `CLIPAttention` as the weights of the module stays
|
141 |
+
untouched. The only required change would be on the forward pass where it needs to correctly call the public API of
|
142 |
+
flash attention and deal with padding tokens in case the input contains any of them.
|
143 |
+
"""
|
144 |
+
def __init__(self, config, is_causal=False):
|
145 |
+
super().__init__(config)
|
146 |
+
self.is_causal = is_causal
|
147 |
+
|
148 |
+
def forward(
|
149 |
+
self,
|
150 |
+
hidden_states: torch.Tensor,
|
151 |
+
attention_mask: Optional[torch.Tensor] = None,
|
152 |
+
causal_attention_mask: Optional[torch.Tensor] = None,
|
153 |
+
output_attentions: Optional[bool] = False,
|
154 |
+
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
|
155 |
+
"""Input shape: Batch x Time x Channel"""
|
156 |
+
if output_attentions:
|
157 |
+
raise ValueError("CLIPFlashAttention2 does not support output_attentions")
|
158 |
+
|
159 |
+
if self.is_causal and causal_attention_mask is None:
|
160 |
+
raise ValueError("CLIPFlashAttention2 has causal=True but no causal_attention_mask provided")
|
161 |
+
|
162 |
+
bsz, tgt_len, embed_dim = hidden_states.size()
|
163 |
+
|
164 |
+
# [batch_size, tgt_len, embed_dim]
|
165 |
+
query_states = self.q_proj(hidden_states)
|
166 |
+
key_states = self.k_proj(hidden_states)
|
167 |
+
value_states = self.v_proj(hidden_states)
|
168 |
+
|
169 |
+
# [batch_size, tgt_len, embed_dim] -> [batch_size, tgt_len, num_heads, head_dim]
|
170 |
+
query_states = query_states.view(bsz, tgt_len, self.num_heads, self.head_dim).contiguous()
|
171 |
+
key_states = key_states.view(bsz, tgt_len, self.num_heads, self.head_dim).contiguous()
|
172 |
+
value_states = value_states.view(bsz, tgt_len, self.num_heads, self.head_dim).contiguous()
|
173 |
+
|
174 |
+
attn_output = self._flash_attention_forward(
|
175 |
+
query_states=query_states,
|
176 |
+
key_states=key_states,
|
177 |
+
value_states=value_states,
|
178 |
+
attention_mask=attention_mask,
|
179 |
+
query_length=tgt_len,
|
180 |
+
dropout=self.dropout,
|
181 |
+
softmax_scale=self.scale,
|
182 |
+
)
|
183 |
+
# [batch_size, tgt_len, num_heads, head_dim] -> [batch_size, tgt_len, embed_dim]
|
184 |
+
attn_output = attn_output.view(bsz, tgt_len, embed_dim)
|
185 |
+
attn_output = self.out_proj(attn_output)
|
186 |
+
|
187 |
+
return attn_output, None
|
188 |
+
|
189 |
+
# Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2._flash_attention_forward
|
190 |
+
def _flash_attention_forward(
|
191 |
+
self, query_states, key_states, value_states, attention_mask, query_length, dropout=0.0, softmax_scale=None
|
192 |
+
) -> torch.Tensor:
|
193 |
+
"""
|
194 |
+
Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token
|
195 |
+
first unpad the input, then computes the attention scores and pad the final attention scores.
|
196 |
+
|
197 |
+
Args:
|
198 |
+
query_states (`torch.Tensor`):
|
199 |
+
Input query states to be passed to Flash Attention API
|
200 |
+
key_states (`torch.Tensor`):
|
201 |
+
Input key states to be passed to Flash Attention API
|
202 |
+
value_states (`torch.Tensor`):
|
203 |
+
Input value states to be passed to Flash Attention API
|
204 |
+
attention_mask (`torch.Tensor`):
|
205 |
+
The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the
|
206 |
+
position of padding tokens and 1 for the position of non-padding tokens.
|
207 |
+
dropout (`int`, *optional*):
|
208 |
+
Attention dropout
|
209 |
+
softmax_scale (`float`, *optional*):
|
210 |
+
The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim)
|
211 |
+
"""
|
212 |
+
from flash_attn import flash_attn_func, flash_attn_varlen_func
|
213 |
+
from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa
|
214 |
+
# Contains at least one padding token in the sequence
|
215 |
+
if attention_mask is not None:
|
216 |
+
batch_size = query_states.shape[0]
|
217 |
+
query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = self._upad_input(
|
218 |
+
query_states, key_states, value_states, attention_mask, query_length
|
219 |
+
)
|
220 |
+
cu_seqlens_q, cu_seqlens_k = cu_seq_lens
|
221 |
+
max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens
|
222 |
+
|
223 |
+
attn_output_unpad = flash_attn_varlen_func(
|
224 |
+
query_states,
|
225 |
+
key_states,
|
226 |
+
value_states,
|
227 |
+
cu_seqlens_q=cu_seqlens_q,
|
228 |
+
cu_seqlens_k=cu_seqlens_k,
|
229 |
+
max_seqlen_q=max_seqlen_in_batch_q,
|
230 |
+
max_seqlen_k=max_seqlen_in_batch_k,
|
231 |
+
dropout_p=dropout,
|
232 |
+
softmax_scale=softmax_scale,
|
233 |
+
causal=self.is_causal,
|
234 |
+
)
|
235 |
+
|
236 |
+
attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length)
|
237 |
+
else:
|
238 |
+
attn_output = flash_attn_func(
|
239 |
+
query_states, key_states, value_states, dropout, softmax_scale=softmax_scale, causal=self.is_causal
|
240 |
+
)
|
241 |
+
|
242 |
+
return attn_output
|
243 |
+
|
244 |
+
def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query_length):
|
245 |
+
from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa
|
246 |
+
indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask)
|
247 |
+
batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape
|
248 |
+
|
249 |
+
key_layer = index_first_axis(
|
250 |
+
key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k
|
251 |
+
)
|
252 |
+
value_layer = index_first_axis(
|
253 |
+
value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k
|
254 |
+
)
|
255 |
+
if query_length == kv_seq_len:
|
256 |
+
query_layer = index_first_axis(
|
257 |
+
query_layer.reshape(batch_size * kv_seq_len, self.num_heads, head_dim), indices_k
|
258 |
+
)
|
259 |
+
cu_seqlens_q = cu_seqlens_k
|
260 |
+
max_seqlen_in_batch_q = max_seqlen_in_batch_k
|
261 |
+
indices_q = indices_k
|
262 |
+
elif query_length == 1:
|
263 |
+
max_seqlen_in_batch_q = 1
|
264 |
+
# There is a memcpy here, that is very bad.
|
265 |
+
cu_seqlens_q = torch.arange(batch_size + 1, dtype=torch.int32, device=query_layer.device)
|
266 |
+
indices_q = cu_seqlens_q[:-1]
|
267 |
+
query_layer = query_layer.squeeze(1)
|
268 |
+
else:
|
269 |
+
# The :q_len slice assumes right padding.
|
270 |
+
attention_mask = attention_mask[:, :query_length]
|
271 |
+
query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(query_layer, attention_mask)
|
272 |
+
|
273 |
+
return (
|
274 |
+
query_layer,
|
275 |
+
key_layer,
|
276 |
+
value_layer,
|
277 |
+
indices_q,
|
278 |
+
(cu_seqlens_q, cu_seqlens_k),
|
279 |
+
(max_seqlen_in_batch_q, max_seqlen_in_batch_k),
|
280 |
+
)
|
281 |
+
|
282 |
+
|
283 |
+
class SeaLMMCLIPEncoderLayer(CLIPEncoderLayer):
|
284 |
+
def __init__(self, config: CLIPConfig):
|
285 |
+
super(CLIPEncoderLayer, self).__init__()
|
286 |
+
self.embed_dim = config.hidden_size
|
287 |
+
# self.self_attn = LlavaCLIPFlashAttention(config)
|
288 |
+
if is_flash_attn_greater_or_equal_2_10():
|
289 |
+
self.self_attn = CLIPFlashAttention2(config)
|
290 |
+
else:
|
291 |
+
self.self_attn = CLIPAttention(config)
|
292 |
+
self.layer_norm1 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
|
293 |
+
self.mlp = CLIPMLP(config)
|
294 |
+
self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
|
295 |
+
|
296 |
+
|
297 |
+
class SeaLMMCLIPEncoder(CLIPEncoder):
|
298 |
+
"""
|
299 |
+
Transformer encoder consisting of `config.num_hidden_layers` self attention layers. Each layer is a
|
300 |
+
[`CLIPEncoderLayer`].
|
301 |
+
|
302 |
+
Args:
|
303 |
+
config: CLIPConfig
|
304 |
+
"""
|
305 |
+
|
306 |
+
def __init__(self, config: CLIPConfig):
|
307 |
+
super(CLIPEncoder, self).__init__()
|
308 |
+
self.config = config
|
309 |
+
self.layers = nn.ModuleList([SeaLMMCLIPEncoderLayer(config) for _ in range(config.num_hidden_layers)])
|
310 |
+
self.gradient_checkpointing = False
|
311 |
+
|
312 |
+
def forward(
|
313 |
+
self,
|
314 |
+
inputs_embeds,
|
315 |
+
attention_mask: Optional[torch.Tensor] = None,
|
316 |
+
causal_attention_mask: Optional[torch.Tensor] = None,
|
317 |
+
output_attentions: Optional[bool] = None,
|
318 |
+
output_hidden_states: Optional[bool] = None,
|
319 |
+
return_dict: Optional[bool] = None,
|
320 |
+
) -> Union[Tuple, BaseModelOutput]:
|
321 |
+
r"""
|
322 |
+
Args:
|
323 |
+
inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
|
324 |
+
Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation.
|
325 |
+
This is useful if you want more control over how to convert `input_ids` indices into associated vectors
|
326 |
+
than the model's internal embedding lookup matrix.
|
327 |
+
attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
|
328 |
+
Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
|
329 |
+
|
330 |
+
- 1 for tokens that are **not masked**,
|
331 |
+
- 0 for tokens that are **masked**.
|
332 |
+
|
333 |
+
[What are attention masks?](../glossary#attention-mask)
|
334 |
+
causal_attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
|
335 |
+
Causal mask for the text model. Mask values selected in `[0, 1]`:
|
336 |
+
|
337 |
+
- 1 for tokens that are **not masked**,
|
338 |
+
- 0 for tokens that are **masked**.
|
339 |
+
|
340 |
+
[What are attention masks?](../glossary#attention-mask)
|
341 |
+
output_attentions (`bool`, *optional*):
|
342 |
+
Whether or not to return the attentions tensors of all attention layers. See `attentions` under
|
343 |
+
returned tensors for more detail.
|
344 |
+
output_hidden_states (`bool`, *optional*):
|
345 |
+
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
|
346 |
+
for more detail.
|
347 |
+
return_dict (`bool`, *optional*):
|
348 |
+
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
|
349 |
+
"""
|
350 |
+
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
351 |
+
output_hidden_states = (
|
352 |
+
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
353 |
+
)
|
354 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
355 |
+
output_hidden_states = False
|
356 |
+
output_attentions = False
|
357 |
+
# return_dict = False
|
358 |
+
|
359 |
+
encoder_states = () if output_hidden_states else None
|
360 |
+
all_attentions = () if output_attentions else None
|
361 |
+
|
362 |
+
hidden_states = inputs_embeds
|
363 |
+
for idx, encoder_layer in enumerate(self.layers):
|
364 |
+
if output_hidden_states:
|
365 |
+
encoder_states = encoder_states + (hidden_states,)
|
366 |
+
# if self.gradient_checkpointing and self.training:
|
367 |
+
# layer_outputs = self._gradient_checkpointing_func(
|
368 |
+
# encoder_layer.__call__,
|
369 |
+
# hidden_states,
|
370 |
+
# attention_mask,
|
371 |
+
# causal_attention_mask,
|
372 |
+
# output_attentions,
|
373 |
+
# )
|
374 |
+
# else:
|
375 |
+
# ! enforce no checkpointing here
|
376 |
+
layer_outputs = encoder_layer(
|
377 |
+
hidden_states,
|
378 |
+
attention_mask,
|
379 |
+
causal_attention_mask,
|
380 |
+
output_attentions=output_attentions,
|
381 |
+
)
|
382 |
+
|
383 |
+
hidden_states = layer_outputs[0]
|
384 |
+
|
385 |
+
if output_attentions:
|
386 |
+
all_attentions = all_attentions + (layer_outputs[1],)
|
387 |
+
|
388 |
+
if output_hidden_states:
|
389 |
+
encoder_states = encoder_states + (hidden_states,)
|
390 |
+
|
391 |
+
if not return_dict:
|
392 |
+
return tuple(v for v in [hidden_states, encoder_states, all_attentions] if v is not None)
|
393 |
+
return BaseModelOutput(
|
394 |
+
last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions
|
395 |
+
)
|
396 |
+
|
397 |
+
|
398 |
+
class SeaLMMVisionTransformer(nn.Module):
|
399 |
+
def __init__(self, config: CLIPVisionConfig):
|
400 |
+
super().__init__()
|
401 |
+
self.config = config
|
402 |
+
embed_dim = config.hidden_size
|
403 |
+
|
404 |
+
self.embeddings = CLIPVisionEmbeddings(config)
|
405 |
+
self.pre_layrnorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)
|
406 |
+
# self.encoder = CLIPEncoder(config)
|
407 |
+
self.encoder = SeaLMMCLIPEncoder(config)
|
408 |
+
# self.post_layernorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)
|
409 |
+
|
410 |
+
@add_start_docstrings_to_model_forward(CLIP_VISION_INPUTS_DOCSTRING)
|
411 |
+
@replace_return_docstrings(output_type=BaseModelOutputWithPooling, config_class=CLIPVisionConfig)
|
412 |
+
def forward(
|
413 |
+
self,
|
414 |
+
pixel_values: Optional[torch.FloatTensor] = None,
|
415 |
+
output_attentions: Optional[bool] = None,
|
416 |
+
output_hidden_states: Optional[bool] = None,
|
417 |
+
return_dict: Optional[bool] = None,
|
418 |
+
) -> Union[Tuple, BaseModelOutputWithPooling]:
|
419 |
+
r"""
|
420 |
+
Returns:
|
421 |
+
|
422 |
+
"""
|
423 |
+
assert output_attentions is None
|
424 |
+
assert output_hidden_states is None
|
425 |
+
# assert return_dict is None
|
426 |
+
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
427 |
+
output_hidden_states = (
|
428 |
+
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
429 |
+
)
|
430 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
431 |
+
|
432 |
+
if pixel_values is None:
|
433 |
+
raise ValueError("You have to specify pixel_values")
|
434 |
+
|
435 |
+
hidden_states = self.embeddings(pixel_values)
|
436 |
+
hidden_states = self.pre_layrnorm(hidden_states)
|
437 |
+
|
438 |
+
encoder_outputs = self.encoder(
|
439 |
+
inputs_embeds=hidden_states,
|
440 |
+
output_attentions=output_attentions,
|
441 |
+
output_hidden_states=output_hidden_states,
|
442 |
+
return_dict=return_dict,
|
443 |
+
)
|
444 |
+
|
445 |
+
last_hidden_state = encoder_outputs[0]
|
446 |
+
|
447 |
+
if not return_dict:
|
448 |
+
raise ValueError(f'Not support return_dict')
|
449 |
+
|
450 |
+
return BaseModelOutputWithPooling(
|
451 |
+
last_hidden_state=last_hidden_state,
|
452 |
+
# pooler_output=pooled_output,
|
453 |
+
pooler_output=None,
|
454 |
+
hidden_states=encoder_outputs.hidden_states,
|
455 |
+
attentions=encoder_outputs.attentions,
|
456 |
+
)
|
457 |
+
|
458 |
+
|
459 |
+
@add_start_docstrings(
|
460 |
+
"""The vision model from CLIP without any head or projection on top.""",
|
461 |
+
CLIP_START_DOCSTRING,
|
462 |
+
)
|
463 |
+
class SeaLMMCLIPVisionModel(CLIPPreTrainedModel):
|
464 |
+
config_class = CLIPVisionConfig
|
465 |
+
main_input_name = "pixel_values"
|
466 |
+
_no_split_modules = ["SeaLMMCLIPEncoderLayer"]
|
467 |
+
|
468 |
+
def __init__(self, config: CLIPVisionConfig):
|
469 |
+
super().__init__(config)
|
470 |
+
self.vision_model = SeaLMMVisionTransformer(config)
|
471 |
+
# Initialize weights and apply final processing
|
472 |
+
self.post_init()
|
473 |
+
|
474 |
+
def get_input_embeddings(self) -> nn.Module:
|
475 |
+
return self.vision_model.embeddings.patch_embedding
|
476 |
+
|
477 |
+
@add_start_docstrings_to_model_forward(CLIP_VISION_INPUTS_DOCSTRING)
|
478 |
+
@replace_return_docstrings(output_type=BaseModelOutputWithPooling, config_class=CLIPVisionConfig)
|
479 |
+
def forward(
|
480 |
+
self,
|
481 |
+
pixel_values: Optional[torch.FloatTensor] = None,
|
482 |
+
output_attentions: Optional[bool] = None,
|
483 |
+
output_hidden_states: Optional[bool] = None,
|
484 |
+
return_dict: Optional[bool] = None,
|
485 |
+
) -> Union[Tuple, BaseModelOutputWithPooling]:
|
486 |
+
r"""
|
487 |
+
Returns:
|
488 |
+
|
489 |
+
Examples:
|
490 |
+
|
491 |
+
```python
|
492 |
+
>>> from PIL import Image
|
493 |
+
>>> import requests
|
494 |
+
>>> from transformers import AutoProcessor, CLIPVisionModel
|
495 |
+
|
496 |
+
>>> model = CLIPVisionModel.from_pretrained("openai/clip-vit-base-patch32")
|
497 |
+
>>> processor = AutoProcessor.from_pretrained("openai/clip-vit-base-patch32")
|
498 |
+
|
499 |
+
>>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
|
500 |
+
>>> image = Image.open(requests.get(url, stream=True).raw)
|
501 |
+
|
502 |
+
>>> inputs = processor(images=image, return_tensors="pt")
|
503 |
+
|
504 |
+
>>> outputs = model(**inputs)
|
505 |
+
>>> last_hidden_state = outputs.last_hidden_state
|
506 |
+
>>> pooled_output = outputs.pooler_output # pooled CLS states
|
507 |
+
```"""
|
508 |
+
# return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
509 |
+
|
510 |
+
return self.vision_model(
|
511 |
+
pixel_values=pixel_values,
|
512 |
+
output_attentions=output_attentions,
|
513 |
+
output_hidden_states=output_hidden_states,
|
514 |
+
return_dict=return_dict,
|
515 |
+
)
|
516 |
+
|
517 |
+
|
518 |
+
class SeaLMMMultiModalProjector(SeaLMMCLIPEncoder):
|
519 |
+
def __init__(self, config: SeaLMMConfig):
|
520 |
+
super(CLIPEncoder, self).__init__()
|
521 |
+
self.config = config
|
522 |
+
self.projector_num_layers = getattr(config, "projector_num_layers", 2)
|
523 |
+
self.vision_config = config.vision_config
|
524 |
+
self.num_vision_feature_layer = int(0 - config.vision_feature_layer) - 1
|
525 |
+
|
526 |
+
assert self.num_vision_feature_layer > 0
|
527 |
+
|
528 |
+
self.layers = nn.ModuleList([
|
529 |
+
# LlavaCLIPFasterEncoderLayer(self.vision_config)
|
530 |
+
SeaLMMCLIPEncoderLayer(self.vision_config)
|
531 |
+
for _ in range(self.projector_num_layers)]
|
532 |
+
)
|
533 |
+
|
534 |
+
projector_layernorm_eps = getattr(config, "projector_layernorm_eps", 1e-05)
|
535 |
+
self.projector_layernorm = nn.LayerNorm(
|
536 |
+
# len(config.vision_feature_layers) * config.vision_config.hidden_size, eps=projector_layernorm_eps
|
537 |
+
config.vision_config.hidden_size, eps=projector_layernorm_eps
|
538 |
+
)
|
539 |
+
|
540 |
+
self.linear_1 = nn.Linear(
|
541 |
+
# len(config.vision_feature_layers) * config.vision_config.hidden_size,
|
542 |
+
config.vision_config.hidden_size,
|
543 |
+
config.text_config.hidden_size,
|
544 |
+
bias=True,
|
545 |
+
)
|
546 |
+
# self.act = ACT2FN[config.projector_hidden_act]
|
547 |
+
# self.linear_2 = nn.Linear(config.text_config.hidden_size, config.text_config.hidden_size, bias=True)
|
548 |
+
|
549 |
+
self.gradient_checkpointing = False
|
550 |
+
|
551 |
+
def forward(self, hidden_states, attention_mask=None, causal_attention_mask=None):
|
552 |
+
"""
|
553 |
+
hidden_states must not be striped
|
554 |
+
"""
|
555 |
+
output_attentions = False
|
556 |
+
|
557 |
+
for idx, encoder_layer in enumerate(self.layers):
|
558 |
+
# if output_hidden_states:
|
559 |
+
# encoder_states = encoder_states + (hidden_states,)
|
560 |
+
# if self.gradient_checkpointing and self.training:
|
561 |
+
# layer_outputs = self._gradient_checkpointing_func(
|
562 |
+
# encoder_layer.__call__,
|
563 |
+
# hidden_states,
|
564 |
+
# attention_mask,
|
565 |
+
# causal_attention_mask,
|
566 |
+
# output_attentions,
|
567 |
+
# )
|
568 |
+
# else:
|
569 |
+
# ! turn off checkpointing
|
570 |
+
layer_outputs = encoder_layer(
|
571 |
+
hidden_states,
|
572 |
+
attention_mask,
|
573 |
+
causal_attention_mask,
|
574 |
+
output_attentions=output_attentions,
|
575 |
+
)
|
576 |
+
|
577 |
+
hidden_states = layer_outputs[0]
|
578 |
+
|
579 |
+
hidden_states = hidden_states[:, 1:]
|
580 |
+
|
581 |
+
hidden_states = self.projector_layernorm(hidden_states)
|
582 |
+
hidden_states = self.linear_1(hidden_states)
|
583 |
+
# hidden_states = self.act(hidden_states)
|
584 |
+
# hidden_states = self.linear_2(hidden_states)
|
585 |
+
return hidden_states
|
586 |
+
|
587 |
+
|
588 |
+
|
589 |
+
@add_start_docstrings(
|
590 |
+
"""The CLip- LLAVA model which consists of a vision backbone and a language model.""",
|
591 |
+
LLAVA_START_DOCSTRING,
|
592 |
+
)
|
593 |
+
class SeaLMMForCausalLM(LlavaPreTrainedModel):
|
594 |
+
def __init__(self, config: SeaLMMConfig, vision_tower=None, language_model=None):
|
595 |
+
super().__init__(config)
|
596 |
+
# self.vision_tower = AutoModel.from_config(config.vision_config)
|
597 |
+
# self.vision_tower = vision_tower or LlavaCLIPVisionModel(config=config.vision_config)
|
598 |
+
self.vision_tower = vision_tower or SeaLMMCLIPVisionModel(config=config.vision_config)
|
599 |
+
self.multi_modal_projector = SeaLMMMultiModalProjector(config)
|
600 |
+
self.vocab_size = config.vocab_size
|
601 |
+
self.language_model = language_model or AutoModelForCausalLM.from_config(
|
602 |
+
config.text_config, attn_implementation=config._attn_implementation
|
603 |
+
)
|
604 |
+
self.pad_token_id = self.config.pad_token_id if self.config.pad_token_id is not None else -1
|
605 |
+
self.post_init()
|
606 |
+
|
607 |
+
self.freeze_vision_tower = True
|
608 |
+
|
609 |
+
def unfreeze_vision_tower(self):
|
610 |
+
logger.info(f'UNFREEZE {self.freeze_vision_tower=}')
|
611 |
+
self.freeze_vision_tower = False
|
612 |
+
|
613 |
+
def freeze_vision_tower(self):
|
614 |
+
logger.info(f'FREEZE {self.freeze_vision_tower=}')
|
615 |
+
self.freeze_vision_tower = True
|
616 |
+
|
617 |
+
@classmethod
|
618 |
+
def create_model_config_from_components(
|
619 |
+
cls,
|
620 |
+
lm_config=None,
|
621 |
+
vision_config=None,
|
622 |
+
tokenizer=None,
|
623 |
+
vision_feature_layer=None,
|
624 |
+
projector_num_layers=1,
|
625 |
+
**kwargs,
|
626 |
+
) -> SeaLMMConfig:
|
627 |
+
# self.projector_num_layers = kwargs.get("projector_num_layers", 1)
|
628 |
+
config = SeaLMMConfig(vision_config, lm_config, projector_num_layers=projector_num_layers, **kwargs)
|
629 |
+
config.vision_feature_layer = config.vision_feature_layer if vision_feature_layer is None else vision_feature_layer
|
630 |
+
|
631 |
+
if config.vision_feature_layer < 0:
|
632 |
+
config.vision_config.num_hidden_layers = config.vision_config.num_hidden_layers + config.vision_feature_layer + 1
|
633 |
+
else:
|
634 |
+
config.vision_config.num_hidden_layers = config.vision_feature_layer + 1
|
635 |
+
|
636 |
+
if IMAGE_TOKEN not in tokenizer.get_vocab():
|
637 |
+
tokenizer.add_special_tokens({"cls_token": IMAGE_TOKEN})
|
638 |
+
|
639 |
+
config.image_token_index = tokenizer.cls_token_id
|
640 |
+
config.vocab_size = config.text_config.vocab_size
|
641 |
+
config.architectures = ["SeaLMMForCausalLM"]
|
642 |
+
return config
|
643 |
+
|
644 |
+
def get_input_embeddings(self):
|
645 |
+
return self.language_model.get_input_embeddings()
|
646 |
+
|
647 |
+
def set_input_embeddings(self, value):
|
648 |
+
self.language_model.set_input_embeddings(value)
|
649 |
+
|
650 |
+
def get_output_embeddings(self):
|
651 |
+
return self.language_model.get_output_embeddings()
|
652 |
+
|
653 |
+
def set_output_embeddings(self, new_embeddings):
|
654 |
+
self.language_model.set_output_embeddings(new_embeddings)
|
655 |
+
|
656 |
+
def set_decoder(self, decoder):
|
657 |
+
self.language_model.set_decoder(decoder)
|
658 |
+
|
659 |
+
def get_decoder(self):
|
660 |
+
return self.language_model.get_decoder()
|
661 |
+
|
662 |
+
def tie_weights(self):
|
663 |
+
return self.language_model.tie_weights()
|
664 |
+
|
665 |
+
def resize_token_embeddings(self, new_num_tokens: Optional[int] = None, pad_to_multiple_of=None) -> nn.Embedding:
|
666 |
+
model_embeds = self.language_model.resize_token_embeddings(new_num_tokens, pad_to_multiple_of)
|
667 |
+
# update vocab size
|
668 |
+
self.config.text_config.vocab_size = model_embeds.num_embeddings
|
669 |
+
self.config.vocab_size = model_embeds.num_embeddings
|
670 |
+
self.vocab_size = model_embeds.num_embeddings
|
671 |
+
return model_embeds
|
672 |
+
|
673 |
+
# @torch.no_grad
|
674 |
+
def _merge_input_ids_with_image_features(
|
675 |
+
self, image_features, inputs_embeds, input_ids, attention_mask, position_ids, labels=None
|
676 |
+
):
|
677 |
+
"""
|
678 |
+
input_ids: [b, tlen]
|
679 |
+
input_embeds: [b, tlen, dt]
|
680 |
+
image_features: [b, ilen, ifeat, di]
|
681 |
+
labels: None or [b, tlen] --> must extend labels to input_ids,
|
682 |
+
|
683 |
+
# in input_ids, there may be image_token_index, number of image_token_index <= ilen
|
684 |
+
input_ids: [
|
685 |
+
a b c d e f X g h i j k X l m
|
686 |
+
o p q r X s t u v _ _ _ _ _ _
|
687 |
+
]
|
688 |
+
input_ids should be: [
|
689 |
+
a b c d e f X X X X X g h i j k X X X X X l m
|
690 |
+
o p q r X X X X X s t u v _ _ _ _ _ _ _ _ _ _
|
691 |
+
]
|
692 |
+
labels should be: [
|
693 |
+
a b c d e f _ _ _ _ _ g h i j k _ _ _ _ _ l m
|
694 |
+
o p q r _ _ _ _ _ s t u v _ _ _ _ _ _ _ _ _ _
|
695 |
+
]
|
696 |
+
# mask replace image onto it
|
697 |
+
|
698 |
+
# Use torch.vmap for simplicy
|
699 |
+
def sample_merge():
|
700 |
+
input_ids: [tlen]
|
701 |
+
input_embeds: [tlen, dt]
|
702 |
+
img_embeds: [ilen, ifeat, di]
|
703 |
+
e.g:
|
704 |
+
input_ids: [
|
705 |
+
a b c d e f X g h i j k X l m
|
706 |
+
]
|
707 |
+
img_embeds: [3, ifeat, id] # img_embeds has padding
|
708 |
+
|
709 |
+
|
710 |
+
"""
|
711 |
+
with torch.no_grad():
|
712 |
+
num_images, num_image_patches, embed_dim = image_features.shape
|
713 |
+
batch_size, sequence_length = input_ids.shape
|
714 |
+
# left_padding = not torch.sum(input_ids[:, -1] == torch.tensor(self.pad_token_id))
|
715 |
+
left_padding = torch.any(attention_mask[:, 0] == 0)
|
716 |
+
# assert not left_padding or batch_size == 1
|
717 |
+
# 1. Create a mask to know where special image tokens are
|
718 |
+
special_image_token_mask = input_ids == self.config.image_token_index
|
719 |
+
num_special_image_tokens = torch.sum(special_image_token_mask, dim=-1)
|
720 |
+
# Reserve for padding of num_images
|
721 |
+
total_num_special_image_tokens = torch.sum(special_image_token_mask)
|
722 |
+
assert total_num_special_image_tokens == num_images, f'{total_num_special_image_tokens=} != {num_images=} | {image_features.shape} {input_ids}'
|
723 |
+
# Compute the maximum embed dimension
|
724 |
+
max_embed_dim = (num_special_image_tokens.max() * (num_image_patches - 1)) + sequence_length
|
725 |
+
batch_indices, non_image_indices = torch.where(input_ids != self.config.image_token_index)
|
726 |
+
|
727 |
+
# 2. Compute the positions where text should be written
|
728 |
+
# Calculate new positions for text tokens in merged image-text sequence.
|
729 |
+
# `special_image_token_mask` identifies image tokens. Each image token will be replaced by `nb_text_tokens_per_images - 1` text tokens.
|
730 |
+
# `torch.cumsum` computes how each image token shifts subsequent text token positions.
|
731 |
+
# - 1 to adjust for zero-based indexing, as `cumsum` inherently increases indices by one.
|
732 |
+
new_token_positions = torch.cumsum((special_image_token_mask * (num_image_patches - 1) + 1), -1) - 1
|
733 |
+
nb_image_pad = max_embed_dim - 1 - new_token_positions[:, -1]
|
734 |
+
if left_padding:
|
735 |
+
new_token_positions += nb_image_pad[:, None] # offset for left padding
|
736 |
+
text_to_overwrite = new_token_positions[batch_indices, non_image_indices]
|
737 |
+
|
738 |
+
# 3. Create the full embedding, already padded to the maximum position
|
739 |
+
final_embedding = torch.zeros(
|
740 |
+
batch_size, max_embed_dim, embed_dim, dtype=inputs_embeds.dtype, device=inputs_embeds.device
|
741 |
+
)
|
742 |
+
final_attention_mask = torch.zeros(
|
743 |
+
batch_size, max_embed_dim, dtype=attention_mask.dtype, device=inputs_embeds.device
|
744 |
+
)
|
745 |
+
final_labels = None
|
746 |
+
if labels is not None:
|
747 |
+
final_labels = torch.full_like(final_attention_mask, self.config.ignore_index).to(torch.long)
|
748 |
+
|
749 |
+
# In case the Vision model or the Language model has been offloaded to CPU, we need to manually
|
750 |
+
# set the corresponding tensors into their correct target device.
|
751 |
+
target_device = inputs_embeds.device
|
752 |
+
batch_indices, non_image_indices, text_to_overwrite = (
|
753 |
+
batch_indices.to(target_device),
|
754 |
+
non_image_indices.to(target_device),
|
755 |
+
text_to_overwrite.to(target_device),
|
756 |
+
)
|
757 |
+
attention_mask = attention_mask.to(target_device)
|
758 |
+
|
759 |
+
# 4. Fill the embeddings based on the mask. If we have ["hey" "<image>", "how", "are"]
|
760 |
+
# we need to index copy on [0, 577, 578, 579] for the text and [1:576] for the image features
|
761 |
+
final_embedding[batch_indices, text_to_overwrite] = inputs_embeds[batch_indices, non_image_indices]
|
762 |
+
final_attention_mask[batch_indices, text_to_overwrite] = attention_mask[batch_indices, non_image_indices]
|
763 |
+
if labels is not None:
|
764 |
+
final_labels[batch_indices, text_to_overwrite] = labels[batch_indices, non_image_indices]
|
765 |
+
|
766 |
+
# 5. Fill the embeddings corresponding to the images. Anything that is still zeros needs filling
|
767 |
+
image_to_overwrite = torch.all(final_embedding == 0, dim=-1)
|
768 |
+
# image_to_overwrite &= image_to_overwrite.cumsum(-1) - 1 >= nb_image_pad[:, None].to(target_device)
|
769 |
+
if left_padding:
|
770 |
+
image_to_overwrite &= image_to_overwrite.cumsum(-1) - 1 >= nb_image_pad[:, None].to(target_device)
|
771 |
+
else:
|
772 |
+
val = torch.arange(max_embed_dim).unsqueeze(0).to(target_device).expand(batch_size, max_embed_dim) < new_token_positions[:, -1:].to(target_device)
|
773 |
+
image_to_overwrite &= val
|
774 |
+
|
775 |
+
if image_to_overwrite.sum() != image_features.shape[:-1].numel():
|
776 |
+
raise ValueError(
|
777 |
+
f"The input provided to the model are wrong. The number of image tokens is {torch.sum(special_image_token_mask)} while"
|
778 |
+
f" the number of image given to the model is {num_images}. This prevents correct indexing and breaks batch generation."
|
779 |
+
)
|
780 |
+
|
781 |
+
final_embedding[image_to_overwrite] = image_features.contiguous().reshape(-1, embed_dim).to(target_device)
|
782 |
+
final_attention_mask |= image_to_overwrite
|
783 |
+
position_ids = (final_attention_mask.cumsum(-1) - 1).masked_fill_((final_attention_mask == 0), 1)
|
784 |
+
|
785 |
+
if not left_padding:
|
786 |
+
# Making sure its the same
|
787 |
+
seq_lens = final_attention_mask.sum(-1)
|
788 |
+
for i, (mask, seq_len) in enumerate(zip(final_attention_mask, seq_lens)):
|
789 |
+
# seq_len = mask.sum(-1)
|
790 |
+
assert torch.all(mask[:seq_len] == 1), f'final 1 mask[{i}]: {seq_len} {final_attention_mask.tolist()=}'
|
791 |
+
assert torch.all(mask[seq_len:] == 0), f'final 0 mask[{i}]: {seq_len} {final_attention_mask.tolist()=}'
|
792 |
+
|
793 |
+
|
794 |
+
# if DEBUG:
|
795 |
+
# print(f'final_attention_mask=\n{final_attention_mask.tolist()}')
|
796 |
+
# print(f'text_to_overwrite=\n{text_to_overwrite.int().tolist()}')
|
797 |
+
# print(f'image_to_overwrite=\n{image_to_overwrite.int().tolist()}')
|
798 |
+
# print(f'position_ids=\n{position_ids.tolist()}')
|
799 |
+
# print(f'labels=\n{labels.tolist()}')
|
800 |
+
# print(f'final_labels=\n{final_labels.tolist()}')
|
801 |
+
|
802 |
+
return final_embedding, final_attention_mask, position_ids, final_labels
|
803 |
+
|
804 |
+
def extract_image_features(self, pixel_values, vision_feature_select_strategy=None):
|
805 |
+
vision_feature_select_strategy = (
|
806 |
+
vision_feature_select_strategy
|
807 |
+
if vision_feature_select_strategy is not None
|
808 |
+
else self.config.vision_feature_select_strategy
|
809 |
+
)
|
810 |
+
with (torch.no_grad() if self.freeze_vision_tower else nullcontext()):
|
811 |
+
image_outputs = self.vision_tower(pixel_values)
|
812 |
+
hiddent_states = image_outputs.last_hidden_state
|
813 |
+
image_features = self.multi_modal_projector(hiddent_states)
|
814 |
+
return image_features
|
815 |
+
|
816 |
+
@add_start_docstrings_to_model_forward(LLAVA_INPUTS_DOCSTRING)
|
817 |
+
@replace_return_docstrings(output_type=LlavaCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
|
818 |
+
def forward(
|
819 |
+
self,
|
820 |
+
input_ids: torch.LongTensor = None,
|
821 |
+
pixel_values: torch.FloatTensor = None,
|
822 |
+
attention_mask: Optional[torch.Tensor] = None,
|
823 |
+
position_ids: Optional[torch.LongTensor] = None,
|
824 |
+
past_key_values: Optional[List[torch.FloatTensor]] = None,
|
825 |
+
inputs_embeds: Optional[torch.FloatTensor] = None,
|
826 |
+
vision_feature_layer: Optional[int] = None,
|
827 |
+
vision_feature_select_strategy: Optional[str] = None,
|
828 |
+
labels: Optional[torch.LongTensor] = None,
|
829 |
+
use_cache: Optional[bool] = None,
|
830 |
+
output_attentions: Optional[bool] = None,
|
831 |
+
output_hidden_states: Optional[bool] = None,
|
832 |
+
return_dict: Optional[bool] = None,
|
833 |
+
) -> Union[Tuple, LlavaCausalLMOutputWithPast]:
|
834 |
+
r"""
|
835 |
+
Args:
|
836 |
+
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
837 |
+
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
|
838 |
+
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
|
839 |
+
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
|
840 |
+
|
841 |
+
Returns:
|
842 |
+
|
843 |
+
Example:
|
844 |
+
|
845 |
+
```python
|
846 |
+
>>> from PIL import Image
|
847 |
+
>>> import requests
|
848 |
+
>>> from transformers import AutoProcessor, LlavaForConditionalGeneration
|
849 |
+
|
850 |
+
>>> model = LlavaForConditionalGeneration.from_pretrained("llava-hf/llava-1.5-7b-hf")
|
851 |
+
>>> processor = AutoProcessor.from_pretrained("llava-hf/llava-1.5-7b-hf")
|
852 |
+
|
853 |
+
>>> prompt = "<image>\nUSER: What's the content of the image?\nASSISTANT:"
|
854 |
+
>>> url = "https://www.ilankelman.org/stopsigns/australia.jpg"
|
855 |
+
>>> image = Image.open(requests.get(url, stream=True).raw)
|
856 |
+
|
857 |
+
>>> inputs = processor(text=prompt, images=image, return_tensors="pt")
|
858 |
+
|
859 |
+
>>> # Generate
|
860 |
+
>>> generate_ids = model.generate(**inputs, max_length=30)
|
861 |
+
>>> processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
|
862 |
+
"\nUSER: What's the content of the image?\nASSISTANT: The image features a stop sign on a street corner"
|
863 |
+
```"""
|
864 |
+
|
865 |
+
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
866 |
+
output_hidden_states = (
|
867 |
+
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
868 |
+
)
|
869 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
870 |
+
vision_feature_layer = (
|
871 |
+
vision_feature_layer if vision_feature_layer is not None else self.config.vision_feature_layer
|
872 |
+
)
|
873 |
+
vision_feature_select_strategy = (
|
874 |
+
vision_feature_select_strategy
|
875 |
+
if vision_feature_select_strategy is not None
|
876 |
+
else self.config.vision_feature_select_strategy
|
877 |
+
)
|
878 |
+
|
879 |
+
if inputs_embeds is None:
|
880 |
+
# 1. Extra the input embeddings
|
881 |
+
for_inputs_embeds_ids = input_ids.clone()
|
882 |
+
for_inputs_embeds_ids[(input_ids == self.config.image_token_index)] = 0
|
883 |
+
# inputs_embeds = self.get_input_embeddings()(input_ids)
|
884 |
+
inputs_embeds = self.get_input_embeddings()(for_inputs_embeds_ids)
|
885 |
+
|
886 |
+
# 2. Merge text and images
|
887 |
+
if pixel_values is not None and input_ids.shape[1] != 1 and pixel_values.size(0) > 0:
|
888 |
+
num_images = pixel_values.size(0)
|
889 |
+
batch_size, sequence_length = input_ids.shape
|
890 |
+
special_image_token_mask = input_ids == self.config.image_token_index
|
891 |
+
# Reserve for padding of num_images
|
892 |
+
total_num_special_image_tokens = torch.sum(special_image_token_mask)
|
893 |
+
assert num_images == total_num_special_image_tokens, (
|
894 |
+
f'{num_images} < {total_num_special_image_tokens} | {special_image_token_mask}'
|
895 |
+
)
|
896 |
+
# pixel_values = pixel_values[:total_num_special_image_tokens]
|
897 |
+
|
898 |
+
# image_outputs = self.vision_tower(pixel_values, output_hidden_states=True)
|
899 |
+
# with (torch.no_grad() if self.freeze_vision_tower else nullcontext()):
|
900 |
+
# image_outputs = self.vision_tower(pixel_values)
|
901 |
+
# # this is not memory efficient at all (output_hidden_states=True) will save all the hidden stated.
|
902 |
+
# # selected_image_feature = image_outputs.hidden_states[vision_feature_layer]
|
903 |
+
# selected_image_feature = image_outputs.last_hidden_state
|
904 |
+
|
905 |
+
# if vision_feature_select_strategy == "default":
|
906 |
+
# selected_image_feature = selected_image_feature[:, 1:]
|
907 |
+
# elif vision_feature_select_strategy == "full":
|
908 |
+
# selected_image_feature = selected_image_feature
|
909 |
+
# else:
|
910 |
+
# raise ValueError(
|
911 |
+
# f"Unexpected select feature strategy: {self.config.vision_feature_select_strategy}"
|
912 |
+
# )
|
913 |
+
|
914 |
+
# image_features = self.multi_modal_projector(selected_image_feature)
|
915 |
+
# print(f"{pixel_values.size()=}")
|
916 |
+
# ! extract_image_features will handle all image features extraction
|
917 |
+
image_features = self.extract_image_features(pixel_values)
|
918 |
+
# if DEBUG:
|
919 |
+
# image_features = image_features[:, :3]
|
920 |
+
|
921 |
+
inputs_embeds, attention_mask, position_ids, labels = self._merge_input_ids_with_image_features(
|
922 |
+
image_features, inputs_embeds, input_ids, attention_mask, position_ids,
|
923 |
+
labels=labels
|
924 |
+
)
|
925 |
+
# if labels is None:
|
926 |
+
# # ! this is wrong!
|
927 |
+
# labels = torch.full_like(attention_mask, self.config.ignore_index).to(torch.long)
|
928 |
+
# print(inputs_embeds.size())
|
929 |
+
|
930 |
+
elif pixel_values is not None and input_ids.shape[1] != 1 and pixel_values.size(0) == 0:
|
931 |
+
# there is no images
|
932 |
+
pass
|
933 |
+
else:
|
934 |
+
# In case input_ids.shape[1] == 1 & pixel_values==None & past_key_values != None, we are in the case of
|
935 |
+
# generation with cache
|
936 |
+
# ! (phi) why do we need to do this?
|
937 |
+
# if past_key_values is not None and pixel_values is not None and input_ids.shape[1] == 1:
|
938 |
+
# # ! it can possible the bug because if mistral, from the first layer_key like this
|
939 |
+
# # ! MUST UNDERSTAND and fix error
|
940 |
+
# # Retrieve the first layer to inspect the logits and mask out the hidden states
|
941 |
+
# # that are set to 0
|
942 |
+
# first_layer_past_key_value = past_key_values[0][0][:, 0, :, 0]
|
943 |
+
# batch_index, non_attended_tokens = torch.where(first_layer_past_key_value == 0)
|
944 |
+
# # Get the target length
|
945 |
+
# target_seqlen = first_layer_past_key_value.shape[-1] + 1
|
946 |
+
|
947 |
+
# extended_attention_mask = torch.ones(
|
948 |
+
# (attention_mask.shape[0], target_seqlen - attention_mask.shape[1]),
|
949 |
+
# dtype=attention_mask.dtype,
|
950 |
+
# device=attention_mask.device,
|
951 |
+
# )
|
952 |
+
# # print(f'{extended_attention_mask.shape} | {batch_index=} | {non_attended_tokens=}')
|
953 |
+
|
954 |
+
# # Zero-out the places where we don't need to attend
|
955 |
+
# extended_attention_mask[batch_index, non_attended_tokens] = 0
|
956 |
+
|
957 |
+
# attention_mask = torch.cat((attention_mask, extended_attention_mask), dim=1)
|
958 |
+
# position_ids = torch.sum(attention_mask, dim=1).unsqueeze(-1) - 1
|
959 |
+
|
960 |
+
# ! fix: https://github.com/huggingface/transformers/blob/c90268de7560c3fef21a927e0bfcf2b611a8711e/src/transformers/models/llava/modeling_llava.py
|
961 |
+
# https://github.com/huggingface/transformers/pull/28032#issuecomment-1863691941
|
962 |
+
if past_key_values is not None and pixel_values is not None and input_ids.shape[1] == 1:
|
963 |
+
# Retrieve the first layer to inspect the logits and mask out the hidden states
|
964 |
+
# that are set to 0
|
965 |
+
first_layer_past_key_value = past_key_values[0][0][:, :, :, 0]
|
966 |
+
|
967 |
+
# Sum all dimensions of head_dim (-2) to avoid random errors such as: https://github.com/huggingface/transformers/pull/28032#issuecomment-1863691941
|
968 |
+
batch_index, non_attended_tokens = torch.where(first_layer_past_key_value.float().sum(-2) == 0)
|
969 |
+
|
970 |
+
# Get the target length
|
971 |
+
target_seqlen = first_layer_past_key_value.shape[-1] + 1
|
972 |
+
|
973 |
+
extended_attention_mask = torch.ones(
|
974 |
+
(attention_mask.shape[0], target_seqlen - attention_mask.shape[1]),
|
975 |
+
dtype=attention_mask.dtype,
|
976 |
+
device=attention_mask.device,
|
977 |
+
)
|
978 |
+
|
979 |
+
# Filter out only the tokens that can be un-attended, this can happen
|
980 |
+
# in the case one uses Llava + Fused modules where the cache on the
|
981 |
+
# first iteration is already big enough, or if one passes custom cache
|
982 |
+
valid_indices = non_attended_tokens < extended_attention_mask.size(-1)
|
983 |
+
new_batch_index = batch_index[valid_indices]
|
984 |
+
new_non_attended_tokens = non_attended_tokens[valid_indices]
|
985 |
+
|
986 |
+
# Zero-out the places where we don't need to attend
|
987 |
+
extended_attention_mask[new_batch_index, new_non_attended_tokens] = 0
|
988 |
+
|
989 |
+
attention_mask = torch.cat((attention_mask, extended_attention_mask), dim=1)
|
990 |
+
position_ids = torch.sum(attention_mask, dim=1).unsqueeze(-1) - 1
|
991 |
+
|
992 |
+
|
993 |
+
outputs = self.language_model(
|
994 |
+
attention_mask=attention_mask,
|
995 |
+
position_ids=position_ids,
|
996 |
+
past_key_values=past_key_values,
|
997 |
+
inputs_embeds=inputs_embeds,
|
998 |
+
use_cache=use_cache,
|
999 |
+
output_attentions=output_attentions,
|
1000 |
+
output_hidden_states=output_hidden_states,
|
1001 |
+
return_dict=return_dict,
|
1002 |
+
)
|
1003 |
+
|
1004 |
+
logits = outputs[0]
|
1005 |
+
|
1006 |
+
loss = None
|
1007 |
+
if labels is not None:
|
1008 |
+
# Shift so that tokens < n predict n
|
1009 |
+
if attention_mask is not None:
|
1010 |
+
shift_attention_mask = attention_mask[..., 1:]
|
1011 |
+
shift_logits = logits[..., :-1, :][shift_attention_mask.to(logits.device) != 0].contiguous()
|
1012 |
+
shift_labels = labels[..., 1:][shift_attention_mask.to(labels.device) != 0].contiguous()
|
1013 |
+
else:
|
1014 |
+
shift_logits = logits[..., :-1, :].contiguous()
|
1015 |
+
shift_labels = labels[..., 1:].contiguous()
|
1016 |
+
# Flatten the tokens
|
1017 |
+
loss_fct = nn.CrossEntropyLoss()
|
1018 |
+
loss = loss_fct(
|
1019 |
+
shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1).to(shift_logits.device)
|
1020 |
+
)
|
1021 |
+
|
1022 |
+
if not return_dict:
|
1023 |
+
output = (logits,) + outputs[1:]
|
1024 |
+
return (loss,) + output if loss is not None else output
|
1025 |
+
|
1026 |
+
return LlavaCausalLMOutputWithPast(
|
1027 |
+
loss=loss,
|
1028 |
+
logits=logits,
|
1029 |
+
past_key_values=outputs.past_key_values,
|
1030 |
+
hidden_states=outputs.hidden_states,
|
1031 |
+
attentions=outputs.attentions,
|
1032 |
+
)
|
1033 |
+
|
1034 |
+
def prepare_inputs_for_generation(
|
1035 |
+
self, input_ids, past_key_values=None, inputs_embeds=None, pixel_values=None, attention_mask=None, **kwargs
|
1036 |
+
):
|
1037 |
+
if past_key_values is not None:
|
1038 |
+
if isinstance(past_key_values, Cache):
|
1039 |
+
cache_length = past_key_values.get_seq_length()
|
1040 |
+
past_length = past_key_values.seen_tokens
|
1041 |
+
else:
|
1042 |
+
cache_length = past_length = past_key_values[0][0].shape[2]
|
1043 |
+
|
1044 |
+
# Keep only the unprocessed tokens:
|
1045 |
+
# 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where
|
1046 |
+
# some of the inputs are exclusivelly passed as part of the cache (e.g. when passing input_embeds as
|
1047 |
+
# input)
|
1048 |
+
if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]:
|
1049 |
+
input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :]
|
1050 |
+
# 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard
|
1051 |
+
# input_ids based on the past_length.
|
1052 |
+
elif past_length < input_ids.shape[1]:
|
1053 |
+
input_ids = input_ids[:, past_length:]
|
1054 |
+
# 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens.
|
1055 |
+
elif self.config.image_token_index in input_ids:
|
1056 |
+
input_ids = input_ids[:, input_ids.shape[1] - 1 :]
|
1057 |
+
# If the cache has seen more tokens than it can hold, then the cache has a size limit. Let's discard the
|
1058 |
+
# older attention values, as their corresponding values are not part of the input.
|
1059 |
+
if cache_length < past_length and attention_mask is not None:
|
1060 |
+
attention_mask = attention_mask[:, -(cache_length + input_ids.shape[1]) :]
|
1061 |
+
|
1062 |
+
position_ids = kwargs.get("position_ids", None)
|
1063 |
+
if attention_mask is not None and position_ids is None:
|
1064 |
+
# create position_ids on the fly for batch generation
|
1065 |
+
position_ids = attention_mask.long().cumsum(-1) - 1
|
1066 |
+
position_ids.masked_fill_(attention_mask == 0, 1)
|
1067 |
+
if past_key_values:
|
1068 |
+
position_ids = position_ids[:, -input_ids.shape[1] :]
|
1069 |
+
|
1070 |
+
# if `inputs_embeds` are passed, we only want to use them in the 1st generation step
|
1071 |
+
if inputs_embeds is not None and past_key_values is None:
|
1072 |
+
model_inputs = {"inputs_embeds": inputs_embeds}
|
1073 |
+
else:
|
1074 |
+
model_inputs = {"input_ids": input_ids}
|
1075 |
+
|
1076 |
+
model_inputs.update(
|
1077 |
+
{
|
1078 |
+
"position_ids": position_ids,
|
1079 |
+
"past_key_values": past_key_values,
|
1080 |
+
"use_cache": kwargs.get("use_cache"),
|
1081 |
+
"attention_mask": attention_mask,
|
1082 |
+
"pixel_values": pixel_values,
|
1083 |
+
}
|
1084 |
+
)
|
1085 |
+
return model_inputs
|
1086 |
+
|
1087 |
+
def _reorder_cache(self, *args, **kwargs):
|
1088 |
+
return self.language_model._reorder_cache(*args, **kwargs)
|
1089 |
+
|
1090 |
+
|
1091 |
+
|
multipurpose_chatbot/engines/sealmmm_engine.py
ADDED
@@ -0,0 +1,269 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# from transformers_stream_generator import init_stream_support
|
2 |
+
# init_stream_support()
|
3 |
+
|
4 |
+
import os
|
5 |
+
import numpy as np
|
6 |
+
import argparse
|
7 |
+
import torch
|
8 |
+
import gradio as gr
|
9 |
+
from typing import Any, Iterator
|
10 |
+
from typing import Iterator, List, Optional, Tuple
|
11 |
+
import filelock
|
12 |
+
import glob
|
13 |
+
import json
|
14 |
+
import time
|
15 |
+
from gradio.routes import Request
|
16 |
+
from gradio.utils import SyncToAsyncIterator, async_iteration
|
17 |
+
from gradio.helpers import special_args
|
18 |
+
import anyio
|
19 |
+
from typing import AsyncGenerator, Callable, Literal, Union, cast
|
20 |
+
|
21 |
+
from gradio_client.documentation import document, set_documentation_group
|
22 |
+
|
23 |
+
from typing import List, Optional, Union, Dict, Tuple
|
24 |
+
from tqdm.auto import tqdm
|
25 |
+
from huggingface_hub import snapshot_download
|
26 |
+
|
27 |
+
from gradio.components import Button
|
28 |
+
from gradio.events import Dependency, EventListenerMethod
|
29 |
+
from transformers import AutoConfig, AutoTokenizer, PreTrainedTokenizer
|
30 |
+
import types
|
31 |
+
import sys
|
32 |
+
from .base_engine import BaseEngine
|
33 |
+
from .transformers_engine import TransformersEngine, NewGenerationMixin
|
34 |
+
|
35 |
+
from ..configs import (
|
36 |
+
STREAM_CHECK_MULTIPLE,
|
37 |
+
STREAM_YIELD_MULTIPLE,
|
38 |
+
)
|
39 |
+
|
40 |
+
CODE_PATH = os.environ.get("CODE_PATH", "")
|
41 |
+
MODEL_PATH = os.environ.get("MODEL_PATH", "")
|
42 |
+
|
43 |
+
IMAGE_TOKEN = "[IMAGE]<|image|>[/IMAGE]"
|
44 |
+
|
45 |
+
IMAGE_LENGTH = 576
|
46 |
+
MAX_PACHES = 1
|
47 |
+
|
48 |
+
|
49 |
+
BLOCK_LANGS = str(os.environ.get("BLOCK_LANGS", ""))
|
50 |
+
BLOCK_LANGS = [x.strip() for x in BLOCK_LANGS.strip().split(";")] if len(BLOCK_LANGS.strip()) > 0 else []
|
51 |
+
LANG_BLOCK_HISTORY = bool(int(os.environ.get("LANG_BLOCK_HISTORY", "0")))
|
52 |
+
KEYWORDS = os.environ.get("KEYWORDS", "").strip()
|
53 |
+
KEYWORDS = KEYWORDS.split(";") if len(KEYWORDS) > 0 else []
|
54 |
+
KEYWORDS = [x.lower() for x in KEYWORDS]
|
55 |
+
|
56 |
+
LANG_BLOCK_MESSAGE = """Unsupported language."""
|
57 |
+
|
58 |
+
KEYWORD_BLOCK_MESSAGE = "Invalid request."
|
59 |
+
|
60 |
+
|
61 |
+
def _detect_lang(text):
|
62 |
+
# Disable language that may have safety risk
|
63 |
+
from langdetect import detect as detect_lang
|
64 |
+
dlang = None
|
65 |
+
try:
|
66 |
+
dlang = detect_lang(text)
|
67 |
+
except Exception as e:
|
68 |
+
if "No features in text." in str(e):
|
69 |
+
return "en"
|
70 |
+
else:
|
71 |
+
return "zh"
|
72 |
+
return dlang
|
73 |
+
|
74 |
+
|
75 |
+
def block_lang(
|
76 |
+
message: str,
|
77 |
+
history: List[Tuple[str, str]] = None,
|
78 |
+
) -> str:
|
79 |
+
# relieve history base block
|
80 |
+
if len(BLOCK_LANGS) == 0:
|
81 |
+
return False
|
82 |
+
|
83 |
+
if LANG_BLOCK_HISTORY and history is not None and any((LANG_BLOCK_MESSAGE in x[1].strip()) for x in history):
|
84 |
+
return True
|
85 |
+
else:
|
86 |
+
_lang = _detect_lang(message)
|
87 |
+
if _lang in BLOCK_LANGS:
|
88 |
+
# print(f'Detect blocked {_lang}: {message}')
|
89 |
+
return True
|
90 |
+
else:
|
91 |
+
return False
|
92 |
+
|
93 |
+
def safety_check(text, history=None, ) -> Optional[str]:
|
94 |
+
"""
|
95 |
+
Despite our effort in safety tuning and red teaming, our models may still generate harmful or illegal content.
|
96 |
+
This provides an additional security measure to enhance safety and compliance with local regulations.
|
97 |
+
"""
|
98 |
+
if len(KEYWORDS) > 0 and any(x in text.lower() for x in KEYWORDS):
|
99 |
+
return KEYWORD_BLOCK_MESSAGE
|
100 |
+
|
101 |
+
if len(BLOCK_LANGS) > 0:
|
102 |
+
if block_lang(text, history):
|
103 |
+
return LANG_BLOCK_MESSAGE
|
104 |
+
|
105 |
+
return None
|
106 |
+
|
107 |
+
|
108 |
+
def safety_check_conversation_string(text, delimiter=None) -> Optional[str]:
|
109 |
+
if len(KEYWORDS) > 0 and any(x in text.lower() for x in KEYWORDS):
|
110 |
+
return KEYWORD_BLOCK_MESSAGE
|
111 |
+
if len(BLOCK_LANGS) > 0:
|
112 |
+
import re
|
113 |
+
delimiter = delimiter or (r"</s><\|im_start\|>user\n", r"</s><\|im_start\|>assistant\n", r"<\|im_start\|>system\n")
|
114 |
+
turns = re.split(r"|".join(delimiter), text)
|
115 |
+
turns = [t for t in turns if t.strip() != '']
|
116 |
+
for t in turns:
|
117 |
+
if block_lang(t):
|
118 |
+
return LANG_BLOCK_MESSAGE
|
119 |
+
return None
|
120 |
+
|
121 |
+
|
122 |
+
def is_check_safety():
|
123 |
+
return len(KEYWORDS) > 0 or len(BLOCK_LANGS) > 0
|
124 |
+
|
125 |
+
|
126 |
+
def safety_check_conversation(conversation) -> Optional[str]:
|
127 |
+
"""
|
128 |
+
Despite our effort in safety tuning and red teaming, our models may still generate harmful or illegal content.
|
129 |
+
This provides an additional security measure to enhance safety and compliance with local regulations.
|
130 |
+
"""
|
131 |
+
texts = [c['content'] for c in conversation]
|
132 |
+
for text in texts:
|
133 |
+
if len(KEYWORDS) > 0 and any(x in text.lower() for x in KEYWORDS):
|
134 |
+
return KEYWORD_BLOCK_MESSAGE
|
135 |
+
|
136 |
+
if len(BLOCK_LANGS) > 0:
|
137 |
+
if block_lang(text):
|
138 |
+
return LANG_BLOCK_MESSAGE
|
139 |
+
return None
|
140 |
+
|
141 |
+
|
142 |
+
class SeaLMMMv0Engine(TransformersEngine):
|
143 |
+
|
144 |
+
@property
|
145 |
+
def image_token(self):
|
146 |
+
return IMAGE_TOKEN
|
147 |
+
|
148 |
+
@property
|
149 |
+
def max_position_embeddings(self) -> int:
|
150 |
+
return self._model.config.max_position_embeddings
|
151 |
+
|
152 |
+
@property
|
153 |
+
def tokenizer(self):
|
154 |
+
return self._tokenizer
|
155 |
+
|
156 |
+
@property
|
157 |
+
def processor(self):
|
158 |
+
return self._processor
|
159 |
+
|
160 |
+
def load_model(self):
|
161 |
+
from transformers import AutoProcessor
|
162 |
+
import sys
|
163 |
+
# caution: path[0] is reserved for script path (or '' in REPL)
|
164 |
+
# sys.path.append(CODE_PATH)
|
165 |
+
|
166 |
+
# from examples.llm.src.models.sealmm.modeling_sealmm import (
|
167 |
+
# SeaLMMForCausalLM
|
168 |
+
# )
|
169 |
+
from modeling_sealmm import (SeaLMMForCausalLM, )
|
170 |
+
model_path = MODEL_PATH
|
171 |
+
print(f'Loading model from {model_path}')
|
172 |
+
|
173 |
+
print(f'model_path={model_path}')
|
174 |
+
if os.path.exists(f"{model_path}/pytorch_model_fsdp.bin") and not os.path.exists(f"{model_path}/pytorch_model.bin"):
|
175 |
+
os.symlink("pytorch_model_fsdp.bin", f"{model_path}/pytorch_model.bin")
|
176 |
+
|
177 |
+
self._processor = AutoProcessor.from_pretrained(model_path)
|
178 |
+
self._model = SeaLMMForCausalLM.from_pretrained(model_path, torch_dtype=torch.bfloat16, device_map="cuda").eval()
|
179 |
+
|
180 |
+
self._model.sample_old = self._model.sample
|
181 |
+
self._model.sample = types.MethodType(NewGenerationMixin.sample_stream, self._model)
|
182 |
+
|
183 |
+
self._tokenizer = self._processor.tokenizer
|
184 |
+
print(self._model)
|
185 |
+
print(f"{self.max_position_embeddings=}")
|
186 |
+
|
187 |
+
def get_multimodal_tokens(self, full_prompt, image_paths=None):
|
188 |
+
num_tokens = len(self.tokenizer.encode(full_prompt))
|
189 |
+
for image_path in image_paths:
|
190 |
+
num_tokens += IMAGE_LENGTH * MAX_PACHES
|
191 |
+
return num_tokens
|
192 |
+
|
193 |
+
def maybe_raise_safety(self, message, gen_index=-1):
|
194 |
+
if is_check_safety():
|
195 |
+
if gen_index < 0:
|
196 |
+
message_safety = safety_check_conversation_string(message)
|
197 |
+
if message_safety is not None:
|
198 |
+
raise gr.Error(message_safety)
|
199 |
+
else:
|
200 |
+
if STREAM_CHECK_MULTIPLE > 0 and gen_index % STREAM_CHECK_MULTIPLE == 0:
|
201 |
+
message_safety = safety_check_conversation_string(message)
|
202 |
+
if message_safety is not None:
|
203 |
+
raise gr.Error(message_safety)
|
204 |
+
|
205 |
+
def generate_yield_string(self, prompt, temperature, max_tokens, stop_strings: Optional[Tuple[str]] = None, **kwargs):
|
206 |
+
from transformers.generation.utils import GenerationConfig
|
207 |
+
from PIL import Image
|
208 |
+
image_paths = kwargs.get("image_paths", None)
|
209 |
+
image_paths = image_paths or []
|
210 |
+
|
211 |
+
images = [Image.open(x) for x in image_paths] if len(image_paths) > 0 else None
|
212 |
+
|
213 |
+
with torch.no_grad():
|
214 |
+
inputs = self.processor(prompt, images, return_tensors='pt')
|
215 |
+
# inputs = {k: v.to("cuda", torch.bfloat16) for k, v in inputs.items() if v is not None}
|
216 |
+
inputs = {k: v.to("cuda") for k, v in inputs.items() if v is not None}
|
217 |
+
num_tokens = self.get_multimodal_tokens(prompt, image_paths)
|
218 |
+
# non-streaming generation
|
219 |
+
# output = self._model.generate(
|
220 |
+
# **inputs,
|
221 |
+
# do_sample=True,
|
222 |
+
# temperature=temperature,
|
223 |
+
# max_new_tokens=max_tokens,
|
224 |
+
# pad_token_id=self.processor.tokenizer.pad_token_id,
|
225 |
+
# )
|
226 |
+
# # response = self.processor.tokenizer.decode(output[0][-inputs.input_ids.size(-1):], skip_special_tokens=True)
|
227 |
+
# full_output_text = self.processor.decode(output[0], skip_special_tokens=True)
|
228 |
+
# response = full_output_text.split("<|im_start|>assistant\n")[-1]
|
229 |
+
# num_tokens = self.get_multimodal_tokens(prompt + response, image_paths)
|
230 |
+
# print(prompt)
|
231 |
+
# print(response)
|
232 |
+
# print(num_tokens)
|
233 |
+
# yield response, num_tokens
|
234 |
+
|
235 |
+
# if i % 4 == 0 and i > 1:
|
236 |
+
# message_safety = safety_check(response)
|
237 |
+
# if message_safety is not None:
|
238 |
+
# history = undo_history(history)
|
239 |
+
# yield history, "", None
|
240 |
+
# raise gr.Error(message_safety)
|
241 |
+
self.maybe_raise_safety(prompt)
|
242 |
+
|
243 |
+
# # ! streaming
|
244 |
+
generator = self._model.generate(
|
245 |
+
**inputs,
|
246 |
+
do_sample=True,
|
247 |
+
temperature=temperature,
|
248 |
+
max_new_tokens=max_tokens,
|
249 |
+
pad_token_id=self.processor.tokenizer.pad_token_id,
|
250 |
+
)
|
251 |
+
|
252 |
+
out_tokens = []
|
253 |
+
response = None
|
254 |
+
for index, token in enumerate(generator):
|
255 |
+
out_tokens.append(token.item())
|
256 |
+
response = self.processor.tokenizer.decode(out_tokens)
|
257 |
+
|
258 |
+
self.maybe_raise_safety(response, gen_index=index)
|
259 |
+
yield response, num_tokens
|
260 |
+
|
261 |
+
del generator
|
262 |
+
|
263 |
+
if response is not None:
|
264 |
+
self.maybe_raise_safety(prompt)
|
265 |
+
|
266 |
+
full_text = prompt + response
|
267 |
+
num_tokens = self.get_multimodal_tokens(full_text, image_paths)
|
268 |
+
yield response, num_tokens
|
269 |
+
|
multipurpose_chatbot/engines/transformers_engine.py
ADDED
@@ -0,0 +1,454 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
|
2 |
+
import os
|
3 |
+
import numpy as np
|
4 |
+
import argparse
|
5 |
+
import torch
|
6 |
+
import gradio as gr
|
7 |
+
from typing import Any, Iterator
|
8 |
+
from typing import Iterator, List, Optional, Tuple
|
9 |
+
import filelock
|
10 |
+
import glob
|
11 |
+
import json
|
12 |
+
import time
|
13 |
+
from gradio.routes import Request
|
14 |
+
from gradio.utils import SyncToAsyncIterator, async_iteration
|
15 |
+
from gradio.helpers import special_args
|
16 |
+
import anyio
|
17 |
+
from typing import AsyncGenerator, Callable, Literal, Union, cast
|
18 |
+
|
19 |
+
from gradio_client.documentation import document, set_documentation_group
|
20 |
+
|
21 |
+
from typing import List, Optional, Union, Dict, Tuple
|
22 |
+
from tqdm.auto import tqdm
|
23 |
+
from huggingface_hub import snapshot_download
|
24 |
+
import types
|
25 |
+
|
26 |
+
from gradio.components import Button
|
27 |
+
from gradio.events import Dependency, EventListenerMethod
|
28 |
+
|
29 |
+
from .base_engine import BaseEngine
|
30 |
+
|
31 |
+
# ! Remember to use static cache
|
32 |
+
|
33 |
+
from transformers import (
|
34 |
+
GenerationConfig,
|
35 |
+
GenerationMixin,
|
36 |
+
LogitsProcessorList,
|
37 |
+
StoppingCriteriaList,
|
38 |
+
DisjunctiveConstraint,
|
39 |
+
BeamSearchScorer,
|
40 |
+
PhrasalConstraint,
|
41 |
+
ConstrainedBeamSearchScorer,
|
42 |
+
PreTrainedModel,
|
43 |
+
)
|
44 |
+
import numpy as np
|
45 |
+
import random
|
46 |
+
import warnings
|
47 |
+
import inspect
|
48 |
+
from transformers.generation.utils import GenerateOutput, SampleOutput, logger
|
49 |
+
import torch
|
50 |
+
from typing import Callable, List, Optional, Union
|
51 |
+
from torch import nn
|
52 |
+
import torch.distributed as dist
|
53 |
+
import copy
|
54 |
+
|
55 |
+
from ..configs import (
|
56 |
+
MODEL_PATH,
|
57 |
+
DTYPE,
|
58 |
+
DEVICE,
|
59 |
+
)
|
60 |
+
|
61 |
+
|
62 |
+
def setup_seed(seed):
|
63 |
+
if seed == -1:
|
64 |
+
return
|
65 |
+
torch.manual_seed(seed)
|
66 |
+
if torch.cuda.is_available():
|
67 |
+
torch.cuda.manual_seed_all(seed)
|
68 |
+
np.random.seed(seed)
|
69 |
+
random.seed(seed)
|
70 |
+
torch.backends.cudnn.deterministic = True
|
71 |
+
|
72 |
+
|
73 |
+
class NewGenerationMixin(GenerationMixin):
|
74 |
+
"""
|
75 |
+
Allow generator sampling
|
76 |
+
|
77 |
+
"""
|
78 |
+
|
79 |
+
# ! Copy from transformers.generation.utils -> GenerationMixin
|
80 |
+
# Change sample function to sample_stream
|
81 |
+
@torch.no_grad()
|
82 |
+
def sample_stream(
|
83 |
+
self,
|
84 |
+
input_ids: torch.LongTensor,
|
85 |
+
logits_processor: Optional[LogitsProcessorList] = None,
|
86 |
+
stopping_criteria: Optional[StoppingCriteriaList] = None,
|
87 |
+
logits_warper: Optional[LogitsProcessorList] = None,
|
88 |
+
max_length: Optional[int] = None,
|
89 |
+
pad_token_id: Optional[int] = None,
|
90 |
+
eos_token_id: Optional[Union[int, List[int]]] = None,
|
91 |
+
output_attentions: Optional[bool] = None,
|
92 |
+
output_hidden_states: Optional[bool] = None,
|
93 |
+
output_scores: Optional[bool] = None,
|
94 |
+
output_logits: Optional[bool] = None,
|
95 |
+
return_dict_in_generate: Optional[bool] = None,
|
96 |
+
synced_gpus: bool = False,
|
97 |
+
streamer: Optional["BaseStreamer"] = None,
|
98 |
+
**model_kwargs,
|
99 |
+
):
|
100 |
+
r"""
|
101 |
+
Generates sequences of token ids for models with a language modeling head using **multinomial sampling** and
|
102 |
+
can be used for text-decoder, text-to-text, speech-to-text, and vision-to-text models.
|
103 |
+
|
104 |
+
<Tip warning={true}>
|
105 |
+
|
106 |
+
In most cases, you do not need to call [`~generation.GenerationMixin.sample`] directly. Use generate() instead.
|
107 |
+
For an overview of generation strategies and code examples, check the [following
|
108 |
+
guide](../generation_strategies).
|
109 |
+
|
110 |
+
</Tip>
|
111 |
+
|
112 |
+
Parameters:
|
113 |
+
input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
|
114 |
+
The sequence used as a prompt for the generation.
|
115 |
+
logits_processor (`LogitsProcessorList`, *optional*):
|
116 |
+
An instance of [`LogitsProcessorList`]. List of instances of class derived from [`LogitsProcessor`]
|
117 |
+
used to modify the prediction scores of the language modeling head applied at each generation step.
|
118 |
+
stopping_criteria (`StoppingCriteriaList`, *optional*):
|
119 |
+
An instance of [`StoppingCriteriaList`]. List of instances of class derived from [`StoppingCriteria`]
|
120 |
+
used to tell if the generation loop should stop.
|
121 |
+
logits_warper (`LogitsProcessorList`, *optional*):
|
122 |
+
An instance of [`LogitsProcessorList`]. List of instances of class derived from [`LogitsWarper`] used
|
123 |
+
to warp the prediction score distribution of the language modeling head applied before multinomial
|
124 |
+
sampling at each generation step.
|
125 |
+
max_length (`int`, *optional*, defaults to 20):
|
126 |
+
**DEPRECATED**. Use `logits_processor` or `stopping_criteria` directly to cap the number of generated
|
127 |
+
tokens. The maximum length of the sequence to be generated.
|
128 |
+
pad_token_id (`int`, *optional*):
|
129 |
+
The id of the *padding* token.
|
130 |
+
eos_token_id (`Union[int, List[int]]`, *optional*):
|
131 |
+
The id of the *end-of-sequence* token. Optionally, use a list to set multiple *end-of-sequence* tokens.
|
132 |
+
output_attentions (`bool`, *optional*, defaults to `False`):
|
133 |
+
Whether or not to return the attentions tensors of all attention layers. See `attentions` under
|
134 |
+
returned tensors for more details.
|
135 |
+
output_hidden_states (`bool`, *optional*, defaults to `False`):
|
136 |
+
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
|
137 |
+
for more details.
|
138 |
+
output_scores (`bool`, *optional*, defaults to `False`):
|
139 |
+
Whether or not to return the prediction scores. See `scores` under returned tensors for more details.
|
140 |
+
output_logits (`bool`, *optional*, defaults to `False`):
|
141 |
+
Whether or not to return the raw prediction logit scores. See `logits` under returned tensors for
|
142 |
+
more details.
|
143 |
+
return_dict_in_generate (`bool`, *optional*, defaults to `False`):
|
144 |
+
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
|
145 |
+
synced_gpus (`bool`, *optional*, defaults to `False`):
|
146 |
+
Whether to continue running the while loop until max_length (needed for ZeRO stage 3)
|
147 |
+
streamer (`BaseStreamer`, *optional*):
|
148 |
+
Streamer object that will be used to stream the generated sequences. Generated tokens are passed
|
149 |
+
through `streamer.put(token_ids)` and the streamer is responsible for any further processing.
|
150 |
+
model_kwargs:
|
151 |
+
Additional model specific kwargs will be forwarded to the `forward` function of the model. If model is
|
152 |
+
an encoder-decoder model the kwargs should include `encoder_outputs`.
|
153 |
+
|
154 |
+
Return:
|
155 |
+
[`~generation.GenerateDecoderOnlyOutput`], [`~generation.GenerateEncoderDecoderOutput`] or `torch.LongTensor`:
|
156 |
+
A `torch.LongTensor` containing the generated tokens (default behaviour) or a
|
157 |
+
[`~generation.GenerateDecoderOnlyOutput`] if `model.config.is_encoder_decoder=False` and
|
158 |
+
`return_dict_in_generate=True` or a [`~generation.GenerateEncoderDecoderOutput`] if
|
159 |
+
`model.config.is_encoder_decoder=True`.
|
160 |
+
|
161 |
+
Examples:
|
162 |
+
|
163 |
+
```python
|
164 |
+
>>> from transformers import (
|
165 |
+
... AutoTokenizer,
|
166 |
+
... AutoModelForCausalLM,
|
167 |
+
... LogitsProcessorList,
|
168 |
+
... MinLengthLogitsProcessor,
|
169 |
+
... TopKLogitsWarper,
|
170 |
+
... TemperatureLogitsWarper,
|
171 |
+
... StoppingCriteriaList,
|
172 |
+
... MaxLengthCriteria,
|
173 |
+
... )
|
174 |
+
>>> import torch
|
175 |
+
|
176 |
+
>>> tokenizer = AutoTokenizer.from_pretrained("openai-community/gpt2")
|
177 |
+
>>> model = AutoModelForCausalLM.from_pretrained("openai-community/gpt2")
|
178 |
+
|
179 |
+
>>> # set pad_token_id to eos_token_id because GPT2 does not have a EOS token
|
180 |
+
>>> model.config.pad_token_id = model.config.eos_token_id
|
181 |
+
>>> model.generation_config.pad_token_id = model.config.eos_token_id
|
182 |
+
|
183 |
+
>>> input_prompt = "Today is a beautiful day, and"
|
184 |
+
>>> input_ids = tokenizer(input_prompt, return_tensors="pt").input_ids
|
185 |
+
|
186 |
+
>>> # instantiate logits processors
|
187 |
+
>>> logits_processor = LogitsProcessorList(
|
188 |
+
... [
|
189 |
+
... MinLengthLogitsProcessor(15, eos_token_id=model.generation_config.eos_token_id),
|
190 |
+
... ]
|
191 |
+
... )
|
192 |
+
>>> # instantiate logits processors
|
193 |
+
>>> logits_warper = LogitsProcessorList(
|
194 |
+
... [
|
195 |
+
... TopKLogitsWarper(50),
|
196 |
+
... TemperatureLogitsWarper(0.7),
|
197 |
+
... ]
|
198 |
+
... )
|
199 |
+
|
200 |
+
>>> stopping_criteria = StoppingCriteriaList([MaxLengthCriteria(max_length=20)])
|
201 |
+
|
202 |
+
>>> torch.manual_seed(0) # doctest: +IGNORE_RESULT
|
203 |
+
>>> outputs = model.sample(
|
204 |
+
... input_ids,
|
205 |
+
... logits_processor=logits_processor,
|
206 |
+
... logits_warper=logits_warper,
|
207 |
+
... stopping_criteria=stopping_criteria,
|
208 |
+
... )
|
209 |
+
|
210 |
+
>>> tokenizer.batch_decode(outputs, skip_special_tokens=True)
|
211 |
+
['Today is a beautiful day, and we must do everything possible to make it a day of celebration.']
|
212 |
+
```"""
|
213 |
+
# init values
|
214 |
+
from transformers.generation.utils import (
|
215 |
+
validate_stopping_criteria, GenerateEncoderDecoderOutput, GenerateDecoderOnlyOutput
|
216 |
+
)
|
217 |
+
logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList()
|
218 |
+
stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList()
|
219 |
+
if max_length is not None:
|
220 |
+
warnings.warn(
|
221 |
+
"`max_length` is deprecated in this function, use"
|
222 |
+
" `stopping_criteria=StoppingCriteriaList([MaxLengthCriteria(max_length=max_length)])` instead.",
|
223 |
+
UserWarning,
|
224 |
+
)
|
225 |
+
stopping_criteria = validate_stopping_criteria(stopping_criteria, max_length)
|
226 |
+
logits_warper = logits_warper if logits_warper is not None else LogitsProcessorList()
|
227 |
+
pad_token_id = pad_token_id if pad_token_id is not None else self.generation_config.pad_token_id
|
228 |
+
eos_token_id = eos_token_id if eos_token_id is not None else self.generation_config.eos_token_id
|
229 |
+
if isinstance(eos_token_id, int):
|
230 |
+
eos_token_id = [eos_token_id]
|
231 |
+
eos_token_id_tensor = torch.tensor(eos_token_id).to(input_ids.device) if eos_token_id is not None else None
|
232 |
+
output_scores = output_scores if output_scores is not None else self.generation_config.output_scores
|
233 |
+
output_logits = output_logits if output_logits is not None else self.generation_config.output_logits
|
234 |
+
output_attentions = (
|
235 |
+
output_attentions if output_attentions is not None else self.generation_config.output_attentions
|
236 |
+
)
|
237 |
+
output_hidden_states = (
|
238 |
+
output_hidden_states if output_hidden_states is not None else self.generation_config.output_hidden_states
|
239 |
+
)
|
240 |
+
return_dict_in_generate = (
|
241 |
+
return_dict_in_generate
|
242 |
+
if return_dict_in_generate is not None
|
243 |
+
else self.generation_config.return_dict_in_generate
|
244 |
+
)
|
245 |
+
|
246 |
+
# init attention / hidden states / scores tuples
|
247 |
+
scores = () if (return_dict_in_generate and output_scores) else None
|
248 |
+
raw_logits = () if (return_dict_in_generate and output_logits) else None
|
249 |
+
decoder_attentions = () if (return_dict_in_generate and output_attentions) else None
|
250 |
+
cross_attentions = () if (return_dict_in_generate and output_attentions) else None
|
251 |
+
decoder_hidden_states = () if (return_dict_in_generate and output_hidden_states) else None
|
252 |
+
|
253 |
+
# if model is an encoder-decoder, retrieve encoder attention weights and hidden states
|
254 |
+
if return_dict_in_generate and self.config.is_encoder_decoder:
|
255 |
+
encoder_attentions = model_kwargs["encoder_outputs"].get("attentions") if output_attentions else None
|
256 |
+
encoder_hidden_states = (
|
257 |
+
model_kwargs["encoder_outputs"].get("hidden_states") if output_hidden_states else None
|
258 |
+
)
|
259 |
+
|
260 |
+
# keep track of which sequences are already finished
|
261 |
+
unfinished_sequences = torch.ones(input_ids.shape[0], dtype=torch.long, device=input_ids.device)
|
262 |
+
|
263 |
+
this_peer_finished = False # used by synced_gpus only
|
264 |
+
# auto-regressive generation
|
265 |
+
while True:
|
266 |
+
if synced_gpus:
|
267 |
+
# Under synced_gpus the `forward` call must continue until all gpus complete their sequence.
|
268 |
+
# The following logic allows an early break if all peers finished generating their sequence
|
269 |
+
this_peer_finished_flag = torch.tensor(0.0 if this_peer_finished else 1.0).to(input_ids.device)
|
270 |
+
# send 0.0 if we finished, 1.0 otherwise
|
271 |
+
dist.all_reduce(this_peer_finished_flag, op=dist.ReduceOp.SUM)
|
272 |
+
# did all peers finish? the reduced sum will be 0.0 then
|
273 |
+
if this_peer_finished_flag.item() == 0.0:
|
274 |
+
break
|
275 |
+
|
276 |
+
# prepare model inputs
|
277 |
+
model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
|
278 |
+
|
279 |
+
# forward pass to get next token
|
280 |
+
outputs = self(
|
281 |
+
**model_inputs,
|
282 |
+
return_dict=True,
|
283 |
+
output_attentions=output_attentions,
|
284 |
+
output_hidden_states=output_hidden_states,
|
285 |
+
)
|
286 |
+
|
287 |
+
if synced_gpus and this_peer_finished:
|
288 |
+
continue # don't waste resources running the code we don't need
|
289 |
+
|
290 |
+
next_token_logits = outputs.logits[:, -1, :]
|
291 |
+
|
292 |
+
# pre-process distribution
|
293 |
+
next_token_scores = logits_processor(input_ids, next_token_logits)
|
294 |
+
next_token_scores = logits_warper(input_ids, next_token_scores)
|
295 |
+
|
296 |
+
# Store scores, attentions and hidden_states when required
|
297 |
+
if return_dict_in_generate:
|
298 |
+
if output_scores:
|
299 |
+
scores += (next_token_scores,)
|
300 |
+
if output_logits:
|
301 |
+
raw_logits += (next_token_logits,)
|
302 |
+
if output_attentions:
|
303 |
+
decoder_attentions += (
|
304 |
+
(outputs.decoder_attentions,) if self.config.is_encoder_decoder else (outputs.attentions,)
|
305 |
+
)
|
306 |
+
if self.config.is_encoder_decoder:
|
307 |
+
cross_attentions += (outputs.cross_attentions,)
|
308 |
+
|
309 |
+
if output_hidden_states:
|
310 |
+
decoder_hidden_states += (
|
311 |
+
(outputs.decoder_hidden_states,)
|
312 |
+
if self.config.is_encoder_decoder
|
313 |
+
else (outputs.hidden_states,)
|
314 |
+
)
|
315 |
+
|
316 |
+
# sample
|
317 |
+
probs = nn.functional.softmax(next_token_scores, dim=-1)
|
318 |
+
next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)
|
319 |
+
|
320 |
+
# finished sentences should have their next token be a padding token
|
321 |
+
if eos_token_id is not None:
|
322 |
+
if pad_token_id is None:
|
323 |
+
raise ValueError("If `eos_token_id` is defined, make sure that `pad_token_id` is defined.")
|
324 |
+
next_tokens = next_tokens * unfinished_sequences + pad_token_id * (1 - unfinished_sequences)
|
325 |
+
|
326 |
+
yield next_tokens.cpu()
|
327 |
+
|
328 |
+
# update generated ids, model inputs, and length for next step
|
329 |
+
input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)
|
330 |
+
if streamer is not None:
|
331 |
+
streamer.put(next_tokens.cpu())
|
332 |
+
|
333 |
+
next_model_inputs = {}
|
334 |
+
if "cache_position" in model_inputs:
|
335 |
+
next_model_inputs['cache_position'] = model_inputs['cache_position']
|
336 |
+
try:
|
337 |
+
model_kwargs = self._update_model_kwargs_for_generation(
|
338 |
+
outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder,
|
339 |
+
# model_inputs=model_inputs
|
340 |
+
model_inputs=next_model_inputs,
|
341 |
+
)
|
342 |
+
except Exception as e:
|
343 |
+
# ! some transformers version don't have model_inputs in generation
|
344 |
+
model_kwargs = self._update_model_kwargs_for_generation(
|
345 |
+
outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder,
|
346 |
+
# model_inputs=model_inputs
|
347 |
+
# model_inputs=next_model_inputs,
|
348 |
+
)
|
349 |
+
|
350 |
+
# if eos_token was found in one sentence, set sentence to finished
|
351 |
+
if eos_token_id_tensor is not None:
|
352 |
+
unfinished_sequences = unfinished_sequences.mul(
|
353 |
+
next_tokens.tile(eos_token_id_tensor.shape[0], 1).ne(eos_token_id_tensor.unsqueeze(1)).prod(dim=0)
|
354 |
+
)
|
355 |
+
|
356 |
+
# stop when each sentence is finished
|
357 |
+
if unfinished_sequences.max() == 0:
|
358 |
+
this_peer_finished = True
|
359 |
+
|
360 |
+
# stop if we exceed the maximum length
|
361 |
+
if stopping_criteria(input_ids, scores):
|
362 |
+
this_peer_finished = True
|
363 |
+
|
364 |
+
if this_peer_finished and not synced_gpus:
|
365 |
+
break
|
366 |
+
|
367 |
+
if streamer is not None:
|
368 |
+
streamer.end()
|
369 |
+
|
370 |
+
# if return_dict_in_generate:
|
371 |
+
# if self.config.is_encoder_decoder:
|
372 |
+
# return GenerateEncoderDecoderOutput(
|
373 |
+
# sequences=input_ids,
|
374 |
+
# scores=scores,
|
375 |
+
# logits=raw_logits,
|
376 |
+
# encoder_attentions=encoder_attentions,
|
377 |
+
# encoder_hidden_states=encoder_hidden_states,
|
378 |
+
# decoder_attentions=decoder_attentions,
|
379 |
+
# cross_attentions=cross_attentions,
|
380 |
+
# decoder_hidden_states=decoder_hidden_states,
|
381 |
+
# past_key_values=model_kwargs.get("past_key_values"),
|
382 |
+
# )
|
383 |
+
# else:
|
384 |
+
# return GenerateDecoderOnlyOutput(
|
385 |
+
# sequences=input_ids,
|
386 |
+
# scores=scores,
|
387 |
+
# logits=raw_logits,
|
388 |
+
# attentions=decoder_attentions,
|
389 |
+
# hidden_states=decoder_hidden_states,
|
390 |
+
# past_key_values=model_kwargs.get("past_key_values"),
|
391 |
+
# )
|
392 |
+
# else:
|
393 |
+
# return input_ids
|
394 |
+
|
395 |
+
|
396 |
+
|
397 |
+
class TransformersEngine(BaseEngine):
|
398 |
+
@property
|
399 |
+
def max_position_embeddings(self) -> int:
|
400 |
+
return self._model.config.max_position_embeddings
|
401 |
+
|
402 |
+
@property
|
403 |
+
def tokenizer(self):
|
404 |
+
return self._tokenizer
|
405 |
+
|
406 |
+
def load_model(self):
|
407 |
+
from transformers import AutoTokenizer, AutoModelForCausalLM
|
408 |
+
import sys
|
409 |
+
# caution: path[0] is reserved for script path (or '' in REPL)
|
410 |
+
# sys.path.append(CODE_PATH)
|
411 |
+
self.model_path = model_path = MODEL_PATH
|
412 |
+
self.torch_dtype = torch.bfloat16 if DTYPE == 'bfloat16' else torch.float16
|
413 |
+
self.device_map = DEVICE
|
414 |
+
print(f'Loading model from {model_path} on {self.device_map} with {self.torch_dtype}')
|
415 |
+
|
416 |
+
self._tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
|
417 |
+
assert self._tokenizer.chat_template is not None and self._tokenizer.chat_template != "", f"{self._tokenizer.chat_template=} not found!"
|
418 |
+
self._model = AutoModelForCausalLM.from_pretrained(model_path, torch_dtype=self.torch_dtype, device_map=self.device_map, trust_remote_code=True).eval()
|
419 |
+
self._model.sample_old = self._model.sample
|
420 |
+
self._model.sample = types.MethodType(NewGenerationMixin.sample_stream, self._model)
|
421 |
+
print(self._model)
|
422 |
+
print(f"{self.max_position_embeddings=}")
|
423 |
+
|
424 |
+
def generate_yield_string(self, prompt, temperature, max_tokens, stop_strings: Optional[Tuple[str]] = None, **kwargs):
|
425 |
+
|
426 |
+
# ! MUST PUT INSIDE torch.no_grad() otherwise it will overflow OOM
|
427 |
+
with torch.no_grad():
|
428 |
+
inputs = self.tokenizer(prompt, return_tensors='pt')
|
429 |
+
num_tokens = inputs.input_ids.size(1)
|
430 |
+
|
431 |
+
inputs = inputs.to(self.device_map)
|
432 |
+
|
433 |
+
generator = self._model.generate(
|
434 |
+
**inputs,
|
435 |
+
do_sample=True,
|
436 |
+
temperature=temperature,
|
437 |
+
max_new_tokens=max_tokens,
|
438 |
+
pad_token_id=self.processor.tokenizer.pad_token_id,
|
439 |
+
)
|
440 |
+
|
441 |
+
out_tokens = []
|
442 |
+
response = None
|
443 |
+
for token in generator:
|
444 |
+
out_tokens.append(token.item())
|
445 |
+
response = self.processor.tokenizer.decode(out_tokens)
|
446 |
+
num_tokens += 1
|
447 |
+
# print(f"{num_tokens=}", end='\r')
|
448 |
+
# sys.stdout.flush()
|
449 |
+
yield response, num_tokens
|
450 |
+
|
451 |
+
if response is not None:
|
452 |
+
full_text = prompt + response
|
453 |
+
num_tokens = len(self.tokenizer.encode(full_text))
|
454 |
+
yield response, num_tokens
|
multipurpose_chatbot/engines/vllm_engine.py
ADDED
@@ -0,0 +1,233 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import numpy as np
|
3 |
+
import argparse
|
4 |
+
import gradio as gr
|
5 |
+
from typing import Any, Iterator
|
6 |
+
from typing import Iterator, List, Optional, Tuple
|
7 |
+
import filelock
|
8 |
+
import glob
|
9 |
+
import json
|
10 |
+
import time
|
11 |
+
from gradio.routes import Request
|
12 |
+
from gradio.utils import SyncToAsyncIterator, async_iteration
|
13 |
+
from gradio.helpers import special_args
|
14 |
+
import anyio
|
15 |
+
from typing import AsyncGenerator, Callable, Literal, Union, cast
|
16 |
+
|
17 |
+
from gradio_client.documentation import document, set_documentation_group
|
18 |
+
|
19 |
+
from typing import List, Optional, Union, Dict, Tuple
|
20 |
+
from tqdm.auto import tqdm
|
21 |
+
from huggingface_hub import snapshot_download
|
22 |
+
|
23 |
+
from gradio.components import Button
|
24 |
+
from gradio.events import Dependency, EventListenerMethod
|
25 |
+
|
26 |
+
from .base_engine import BaseEngine
|
27 |
+
# @@ environments ================
|
28 |
+
|
29 |
+
from ..configs import (
|
30 |
+
DTYPE,
|
31 |
+
TENSOR_PARALLEL,
|
32 |
+
MODEL_PATH,
|
33 |
+
QUANTIZATION,
|
34 |
+
MAX_TOKENS,
|
35 |
+
TEMPERATURE,
|
36 |
+
FREQUENCE_PENALTY,
|
37 |
+
PRESENCE_PENALTY,
|
38 |
+
GPU_MEMORY_UTILIZATION,
|
39 |
+
STREAM_CHECK_MULTIPLE,
|
40 |
+
STREAM_YIELD_MULTIPLE,
|
41 |
+
|
42 |
+
)
|
43 |
+
|
44 |
+
|
45 |
+
llm = None
|
46 |
+
demo = None
|
47 |
+
|
48 |
+
|
49 |
+
|
50 |
+
def vllm_abort(self):
|
51 |
+
sh = self.llm_engine.scheduler
|
52 |
+
for g in (sh.waiting + sh.running + sh.swapped):
|
53 |
+
sh.abort_seq_group(g.request_id)
|
54 |
+
from vllm.sequence import SequenceStatus
|
55 |
+
scheduler = self.llm_engine.scheduler
|
56 |
+
for state_queue in [scheduler.waiting, scheduler.running, scheduler.swapped]:
|
57 |
+
for seq_group in state_queue:
|
58 |
+
# if seq_group.request_id == request_id:
|
59 |
+
# Remove the sequence group from the state queue.
|
60 |
+
state_queue.remove(seq_group)
|
61 |
+
for seq in seq_group.seqs:
|
62 |
+
if seq.is_finished():
|
63 |
+
continue
|
64 |
+
scheduler.free_seq(seq, SequenceStatus.FINISHED_ABORTED)
|
65 |
+
|
66 |
+
|
67 |
+
def _vllm_run_engine(self: Any, use_tqdm: bool = False) -> Dict[str, Any]:
|
68 |
+
from vllm.outputs import RequestOutput
|
69 |
+
# Initialize tqdm.
|
70 |
+
if use_tqdm:
|
71 |
+
num_requests = self.llm_engine.get_num_unfinished_requests()
|
72 |
+
pbar = tqdm(total=num_requests, desc="Processed prompts")
|
73 |
+
# Run the engine.
|
74 |
+
outputs: Dict[str, RequestOutput] = {}
|
75 |
+
while self.llm_engine.has_unfinished_requests():
|
76 |
+
step_outputs = self.llm_engine.step()
|
77 |
+
for output in step_outputs:
|
78 |
+
outputs[output.request_id] = output
|
79 |
+
if len(outputs) > 0:
|
80 |
+
yield outputs
|
81 |
+
|
82 |
+
|
83 |
+
def vllm_generate_stream(
|
84 |
+
self: Any,
|
85 |
+
prompts: Optional[Union[str, List[str]]] = None,
|
86 |
+
sampling_params: Optional[Any] = None,
|
87 |
+
prompt_token_ids: Optional[List[List[int]]] = None,
|
88 |
+
use_tqdm: bool = False,
|
89 |
+
) -> Dict[str, Any]:
|
90 |
+
"""Generates the completions for the input prompts.
|
91 |
+
|
92 |
+
NOTE: This class automatically batches the given prompts, considering
|
93 |
+
the memory constraint. For the best performance, put all of your prompts
|
94 |
+
into a single list and pass it to this method.
|
95 |
+
|
96 |
+
Args:
|
97 |
+
prompts: A list of prompts to generate completions for.
|
98 |
+
sampling_params: The sampling parameters for text generation. If
|
99 |
+
None, we use the default sampling parameters.
|
100 |
+
prompt_token_ids: A list of token IDs for the prompts. If None, we
|
101 |
+
use the tokenizer to convert the prompts to token IDs.
|
102 |
+
use_tqdm: Whether to use tqdm to display the progress bar.
|
103 |
+
|
104 |
+
Returns:
|
105 |
+
A list of `RequestOutput` objects containing the generated
|
106 |
+
completions in the same order as the input prompts.
|
107 |
+
"""
|
108 |
+
from vllm import LLM, SamplingParams
|
109 |
+
if prompts is None and prompt_token_ids is None:
|
110 |
+
raise ValueError("Either prompts or prompt_token_ids must be "
|
111 |
+
"provided.")
|
112 |
+
if isinstance(prompts, str):
|
113 |
+
# Convert a single prompt to a list.
|
114 |
+
prompts = [prompts]
|
115 |
+
if prompts is not None and prompt_token_ids is not None:
|
116 |
+
if len(prompts) != len(prompt_token_ids):
|
117 |
+
raise ValueError("The lengths of prompts and prompt_token_ids "
|
118 |
+
"must be the same.")
|
119 |
+
if sampling_params is None:
|
120 |
+
# Use default sampling params.
|
121 |
+
sampling_params = SamplingParams()
|
122 |
+
# Add requests to the engine.
|
123 |
+
if prompts is not None:
|
124 |
+
num_requests = len(prompts)
|
125 |
+
else:
|
126 |
+
num_requests = len(prompt_token_ids)
|
127 |
+
for i in range(num_requests):
|
128 |
+
prompt = prompts[i] if prompts is not None else None
|
129 |
+
if prompt_token_ids is None:
|
130 |
+
token_ids = None
|
131 |
+
else:
|
132 |
+
token_ids = prompt_token_ids[i]
|
133 |
+
self._add_request(prompt, sampling_params, token_ids)
|
134 |
+
# return self._run_engine(use_tqdm)
|
135 |
+
yield from _vllm_run_engine(self, use_tqdm)
|
136 |
+
|
137 |
+
|
138 |
+
|
139 |
+
class VllmEngine(BaseEngine):
|
140 |
+
def __init__(self, **kwargs) -> None:
|
141 |
+
super().__init__(**kwargs)
|
142 |
+
|
143 |
+
@property
|
144 |
+
def tokenizer(self):
|
145 |
+
return self._model.get_tokenizer()
|
146 |
+
|
147 |
+
def load_model(self, ):
|
148 |
+
import torch
|
149 |
+
try:
|
150 |
+
compute_capability = torch.cuda.get_device_capability()
|
151 |
+
print(f'Torch CUDA compute_capability: {compute_capability}')
|
152 |
+
except Exception as e:
|
153 |
+
print(f'Failed to print compute_capability version: {e}')
|
154 |
+
|
155 |
+
import vllm
|
156 |
+
from vllm import LLM
|
157 |
+
|
158 |
+
print(f'VLLM: {vllm.__version__=}')
|
159 |
+
|
160 |
+
if QUANTIZATION == 'awq':
|
161 |
+
print(F'Load model in int4 quantization')
|
162 |
+
llm = LLM(
|
163 |
+
model=MODEL_PATH,
|
164 |
+
dtype="float16",
|
165 |
+
tensor_parallel_size=TENSOR_PARALLEL,
|
166 |
+
gpu_memory_utilization=GPU_MEMORY_UTILIZATION,
|
167 |
+
quantization="awq",
|
168 |
+
max_model_len=MAX_TOKENS
|
169 |
+
)
|
170 |
+
else:
|
171 |
+
llm = LLM(
|
172 |
+
model=MODEL_PATH,
|
173 |
+
dtype=DTYPE,
|
174 |
+
tensor_parallel_size=TENSOR_PARALLEL,
|
175 |
+
gpu_memory_utilization=GPU_MEMORY_UTILIZATION,
|
176 |
+
max_model_len=MAX_TOKENS
|
177 |
+
)
|
178 |
+
|
179 |
+
try:
|
180 |
+
print(llm.llm_engine.workers[0].model)
|
181 |
+
except Exception as e:
|
182 |
+
print(f'Cannot print model worker: {e}')
|
183 |
+
|
184 |
+
try:
|
185 |
+
llm.llm_engine.scheduler_config.max_model_len = MAX_TOKENS
|
186 |
+
llm.llm_engine.scheduler_config.max_num_batched_tokens = MAX_TOKENS
|
187 |
+
except Exception as e:
|
188 |
+
print(f'Cannot set parameters: {e}')
|
189 |
+
|
190 |
+
self._model = llm
|
191 |
+
|
192 |
+
def generate_yield_string(self, prompt, temperature, max_tokens, stop_strings: Optional[Tuple[str]] = None, **kwargs):
|
193 |
+
from vllm import SamplingParams
|
194 |
+
# ! must abort previous ones
|
195 |
+
vllm_abort(llm)
|
196 |
+
sampling_params = SamplingParams(
|
197 |
+
temperature=temperature,
|
198 |
+
max_tokens=max_tokens,
|
199 |
+
# frequency_penalty=frequency_penalty,
|
200 |
+
# presence_penalty=presence_penalty,
|
201 |
+
stop=stop_strings,
|
202 |
+
)
|
203 |
+
cur_out = None
|
204 |
+
num_tokens = len(self.tokenizer.encode(prompt))
|
205 |
+
for j, gen in enumerate(vllm_generate_stream(llm, prompt, sampling_params)):
|
206 |
+
if cur_out is not None and (STREAM_YIELD_MULTIPLE < 1 or j % STREAM_YIELD_MULTIPLE == 0) and j > 0:
|
207 |
+
yield cur_out, num_tokens
|
208 |
+
assert len(gen) == 1, f'{gen}'
|
209 |
+
item = next(iter(gen.values()))
|
210 |
+
cur_out = item.outputs[0].text
|
211 |
+
|
212 |
+
if cur_out is not None:
|
213 |
+
full_text = prompt + cur_out
|
214 |
+
num_tokens = len(self.tokenizer.encode(full_text))
|
215 |
+
yield cur_out, num_tokens
|
216 |
+
|
217 |
+
def batch_generate(self, prompts, temperature, max_tokens, stop_strings: Optional[Tuple[str]] = None, **kwargs):
|
218 |
+
"""
|
219 |
+
Only vllm should support this, the other engines is only batch=1 only
|
220 |
+
"""
|
221 |
+
from vllm import SamplingParams
|
222 |
+
# ! must abort previous ones
|
223 |
+
vllm_abort(llm)
|
224 |
+
sampling_params = SamplingParams(
|
225 |
+
temperature=temperature,
|
226 |
+
max_tokens=max_tokens,
|
227 |
+
# frequency_penalty=frequency_penalty,
|
228 |
+
# presence_penalty=presence_penalty,
|
229 |
+
stop=stop_strings,
|
230 |
+
)
|
231 |
+
generated = llm.generate(prompts, sampling_params, use_tqdm=False)
|
232 |
+
responses = [g.outputs[0].text for g in generated]
|
233 |
+
return responses
|
multipurpose_chatbot/globals.py
ADDED
@@ -0,0 +1,33 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
|
3 |
+
global MODEL_ENGINE
|
4 |
+
|
5 |
+
from multipurpose_chatbot.engines import load_multipurpose_chatbot_engine
|
6 |
+
from multipurpose_chatbot.demos import get_demo_class
|
7 |
+
|
8 |
+
from .configs import (
|
9 |
+
BACKEND,
|
10 |
+
RAG_EMBED_MODEL_NAME,
|
11 |
+
)
|
12 |
+
|
13 |
+
MODEL_ENGINE = load_multipurpose_chatbot_engine(BACKEND)
|
14 |
+
|
15 |
+
|
16 |
+
RAG_CURRENT_FILE, RAG_EMBED, RAG_CURRENT_VECTORSTORE = None, None, None
|
17 |
+
|
18 |
+
|
19 |
+
def load_embeddings():
|
20 |
+
global RAG_EMBED
|
21 |
+
if RAG_EMBED is None:
|
22 |
+
from langchain_community.embeddings import HuggingFaceEmbeddings, HuggingFaceBgeEmbeddings
|
23 |
+
print(f'LOading embeddings: {RAG_EMBED_MODEL_NAME}')
|
24 |
+
RAG_EMBED = HuggingFaceEmbeddings(model_name=RAG_EMBED_MODEL_NAME, model_kwargs={'trust_remote_code':True, "device": "cpu"})
|
25 |
+
else:
|
26 |
+
print(f'RAG_EMBED ALREADY EXIST: {RAG_EMBED_MODEL_NAME}: {RAG_EMBED=}')
|
27 |
+
return RAG_EMBED
|
28 |
+
|
29 |
+
|
30 |
+
def get_rag_embeddings():
|
31 |
+
return load_embeddings()
|
32 |
+
|
33 |
+
|
pyproject.toml
ADDED
File without changes
|
requirements.txt
CHANGED
@@ -1,3 +1,14 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
sentencepiece
|
2 |
accelerate
|
3 |
evaluate
|
@@ -10,21 +21,8 @@ jiwer
|
|
10 |
tenacity
|
11 |
pynvml
|
12 |
ninja
|
13 |
-
ray
|
14 |
-
psutil
|
15 |
fastapi
|
16 |
geomloss
|
17 |
einops
|
18 |
langdetect
|
19 |
-
transformers
|
20 |
-
transformers_stream_generator
|
21 |
plotly
|
22 |
-
vllm
|
23 |
-
langchain
|
24 |
-
langchain-community
|
25 |
-
langchain-core
|
26 |
-
sentence-transformers
|
27 |
-
faiss-cpu
|
28 |
-
pypdf
|
29 |
-
sentencepiece
|
30 |
-
docx2txt
|
|
|
1 |
+
torch
|
2 |
+
gradio
|
3 |
+
tiktoken
|
4 |
+
openai
|
5 |
+
transformers
|
6 |
+
langchain
|
7 |
+
langchain-community
|
8 |
+
langchain-core
|
9 |
+
chromadb
|
10 |
+
pypdf
|
11 |
+
docx2txt
|
12 |
sentencepiece
|
13 |
accelerate
|
14 |
evaluate
|
|
|
21 |
tenacity
|
22 |
pynvml
|
23 |
ninja
|
|
|
|
|
24 |
fastapi
|
25 |
geomloss
|
26 |
einops
|
27 |
langdetect
|
|
|
|
|
28 |
plotly
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
seallm_app.py
ADDED
@@ -0,0 +1,1787 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright: DAMO Academy, Alibaba Group
|
2 |
+
# By Xuan Phi Nguyen at DAMO Academy, Alibaba Group
|
3 |
+
|
4 |
+
# Description:
|
5 |
+
"""
|
6 |
+
VLLM-based demo script to launch Language chat model for Southeast Asian Languages
|
7 |
+
"""
|
8 |
+
|
9 |
+
|
10 |
+
import os
|
11 |
+
import numpy as np
|
12 |
+
import argparse
|
13 |
+
import torch
|
14 |
+
import gradio as gr
|
15 |
+
from typing import Any, Iterator
|
16 |
+
from typing import Iterator, List, Optional, Tuple
|
17 |
+
import filelock
|
18 |
+
import glob
|
19 |
+
import json
|
20 |
+
import time
|
21 |
+
from gradio.routes import Request
|
22 |
+
from gradio.utils import SyncToAsyncIterator, async_iteration
|
23 |
+
from gradio.helpers import special_args
|
24 |
+
import anyio
|
25 |
+
from typing import AsyncGenerator, Callable, Literal, Union, cast
|
26 |
+
|
27 |
+
from gradio_client.documentation import document, set_documentation_group
|
28 |
+
|
29 |
+
from typing import List, Optional, Union, Dict, Tuple
|
30 |
+
from tqdm.auto import tqdm
|
31 |
+
from huggingface_hub import snapshot_download
|
32 |
+
|
33 |
+
|
34 |
+
# @@ environments ================
|
35 |
+
|
36 |
+
DEBUG = bool(int(os.environ.get("DEBUG", "1")))
|
37 |
+
|
38 |
+
# List of languages to block
|
39 |
+
BLOCK_LANGS = str(os.environ.get("BLOCK_LANGS", ""))
|
40 |
+
BLOCK_LANGS = [x.strip() for x in BLOCK_LANGS.strip().split(";")] if len(BLOCK_LANGS.strip()) > 0 else []
|
41 |
+
|
42 |
+
# for lang block, wether to block in history too
|
43 |
+
LANG_BLOCK_HISTORY = bool(int(os.environ.get("LANG_BLOCK_HISTORY", "0")))
|
44 |
+
TENSOR_PARALLEL = int(os.environ.get("TENSOR_PARALLEL", "1"))
|
45 |
+
DTYPE = os.environ.get("DTYPE", "bfloat16")
|
46 |
+
|
47 |
+
# ! (no debug) whether to download HF_MODEL_NAME and save to MODEL_PATH
|
48 |
+
DOWNLOAD_SNAPSHOT = bool(int(os.environ.get("DOWNLOAD_SNAPSHOT", "0")))
|
49 |
+
LOG_RESPONSE = bool(int(os.environ.get("LOG_RESPONSE", "0")))
|
50 |
+
# ! show model path in the demo page, only for internal
|
51 |
+
DISPLAY_MODEL_PATH = bool(int(os.environ.get("DISPLAY_MODEL_PATH", "1")))
|
52 |
+
|
53 |
+
# ! uploaded model path, will be downloaded to MODEL_PATH
|
54 |
+
HF_MODEL_NAME = os.environ.get("HF_MODEL_NAME", "DAMO-NLP-SG/seal-13b-chat-a")
|
55 |
+
# ! if model is private, need HF_TOKEN to access the model
|
56 |
+
HF_TOKEN = os.environ.get("HF_TOKEN", None)
|
57 |
+
# ! path where the model is downloaded, either on ./ or persistent disc
|
58 |
+
MODEL_PATH = os.environ.get("MODEL_PATH", "./seal-13b-chat-a")
|
59 |
+
|
60 |
+
# ! log path
|
61 |
+
LOG_PATH = os.environ.get("LOG_PATH", "").strip()
|
62 |
+
LOG_FILE = None
|
63 |
+
SAVE_LOGS = LOG_PATH is not None and LOG_PATH != ''
|
64 |
+
if SAVE_LOGS:
|
65 |
+
if os.path.exists(LOG_PATH):
|
66 |
+
print(f'LOG_PATH exist: {LOG_PATH}')
|
67 |
+
else:
|
68 |
+
LOG_DIR = os.path.dirname(LOG_PATH)
|
69 |
+
os.makedirs(LOG_DIR, exist_ok=True)
|
70 |
+
|
71 |
+
# ! get LOG_PATH as aggregated outputs in log
|
72 |
+
GET_LOG_CMD = os.environ.get("GET_LOG_CMD", "").strip()
|
73 |
+
|
74 |
+
print(f'SAVE_LOGS: {SAVE_LOGS} | {LOG_PATH}')
|
75 |
+
# print(f'GET_LOG_CMD: {GET_LOG_CMD}')
|
76 |
+
|
77 |
+
# ! !! Whether to delete the folder, ONLY SET THIS IF YOU WANT TO DELETE SAVED MODEL ON PERSISTENT DISC
|
78 |
+
DELETE_FOLDER = os.environ.get("DELETE_FOLDER", "")
|
79 |
+
IS_DELETE_FOLDER = DELETE_FOLDER is not None and os.path.exists(DELETE_FOLDER)
|
80 |
+
print(f'DELETE_FOLDER: {DELETE_FOLDER} | {DOWNLOAD_SNAPSHOT=}')
|
81 |
+
|
82 |
+
# ! list of keywords to disabled as security measures to comply with local regulation
|
83 |
+
KEYWORDS = os.environ.get("KEYWORDS", "").strip()
|
84 |
+
KEYWORDS = KEYWORDS.split(";") if len(KEYWORDS) > 0 else []
|
85 |
+
KEYWORDS = [x.lower() for x in KEYWORDS]
|
86 |
+
|
87 |
+
# bypass
|
88 |
+
BYPASS_USERS = os.environ.get("BYPASS_USERS", "").strip()
|
89 |
+
BYPASS_USERS = BYPASS_USERS.split(";") if len(BYPASS_USERS) > 0 else []
|
90 |
+
|
91 |
+
# gradio config
|
92 |
+
PORT = int(os.environ.get("PORT", "7860"))
|
93 |
+
# how many iterations to yield response
|
94 |
+
STREAM_YIELD_MULTIPLE = int(os.environ.get("STREAM_YIELD_MULTIPLE", "1"))
|
95 |
+
# how many iterations to perform safety check on response
|
96 |
+
STREAM_CHECK_MULTIPLE = int(os.environ.get("STREAM_CHECK_MULTIPLE", "0"))
|
97 |
+
|
98 |
+
# whether to enable to popup accept user
|
99 |
+
ENABLE_AGREE_POPUP = bool(int(os.environ.get("ENABLE_AGREE_POPUP", "0")))
|
100 |
+
|
101 |
+
# self explanatory
|
102 |
+
MAX_TOKENS = int(os.environ.get("MAX_TOKENS", "2048"))
|
103 |
+
TEMPERATURE = float(os.environ.get("TEMPERATURE", "0.1"))
|
104 |
+
FREQUENCE_PENALTY = float(os.environ.get("FREQUENCE_PENALTY", "0.1"))
|
105 |
+
PRESENCE_PENALTY = float(os.environ.get("PRESENCE_PENALTY", "0.0"))
|
106 |
+
gpu_memory_utilization = float(os.environ.get("gpu_memory_utilization", "0.9"))
|
107 |
+
|
108 |
+
# whether to enable quantization, currently not in use
|
109 |
+
QUANTIZATION = str(os.environ.get("QUANTIZATION", ""))
|
110 |
+
|
111 |
+
|
112 |
+
# Batch inference file upload
|
113 |
+
ENABLE_BATCH_INFER = bool(int(os.environ.get("ENABLE_BATCH_INFER", "1")))
|
114 |
+
BATCH_INFER_MAX_ITEMS = int(os.environ.get("BATCH_INFER_MAX_ITEMS", "100"))
|
115 |
+
BATCH_INFER_MAX_FILE_SIZE = int(os.environ.get("BATCH_INFER_MAX_FILE_SIZE", "500"))
|
116 |
+
BATCH_INFER_MAX_PROMPT_TOKENS = int(os.environ.get("BATCH_INFER_MAX_PROMPT_TOKENS", "4000"))
|
117 |
+
BATCH_INFER_SAVE_TMP_FILE = os.environ.get("BATCH_INFER_SAVE_TMP_FILE", "./tmp/pred.json")
|
118 |
+
|
119 |
+
#
|
120 |
+
DATA_SET_REPO_PATH = str(os.environ.get("DATA_SET_REPO_PATH", ""))
|
121 |
+
DATA_SET_REPO = None
|
122 |
+
|
123 |
+
"""
|
124 |
+
Internal instructions of how to configure the DEMO
|
125 |
+
|
126 |
+
1. Upload SFT model as a model to huggingface: hugginface/models/seal_13b_a
|
127 |
+
2. If the model weights is private, set HF_TOKEN=<your private hf token> in https://huggingface.co/spaces/????/?????/settings
|
128 |
+
3. space config env: `HF_MODEL_NAME=SeaLLMs/seal-13b-chat-a` or the underlining model
|
129 |
+
4. If enable persistent storage: set
|
130 |
+
HF_HOME=/data/.huggingface
|
131 |
+
MODEL_PATH=/data/.huggingface/seal-13b-chat-a
|
132 |
+
if not:
|
133 |
+
MODEL_PATH=./seal-13b-chat-a
|
134 |
+
|
135 |
+
|
136 |
+
HF_HOME=/data/.huggingface
|
137 |
+
MODEL_PATH=/data/ckpt/seal-13b-chat-a
|
138 |
+
DELETE_FOLDER=/data/
|
139 |
+
|
140 |
+
"""
|
141 |
+
|
142 |
+
# ==============================
|
143 |
+
print(f'DEBUG mode: {DEBUG}')
|
144 |
+
print(f'Torch version: {torch.__version__}')
|
145 |
+
try:
|
146 |
+
print(f'Torch CUDA version: {torch.version.cuda}')
|
147 |
+
except Exception as e:
|
148 |
+
print(f'Failed to print cuda version: {e}')
|
149 |
+
|
150 |
+
try:
|
151 |
+
compute_capability = torch.cuda.get_device_capability()
|
152 |
+
print(f'Torch CUDA compute_capability: {compute_capability}')
|
153 |
+
except Exception as e:
|
154 |
+
print(f'Failed to print compute_capability version: {e}')
|
155 |
+
|
156 |
+
|
157 |
+
# @@ constants ================
|
158 |
+
|
159 |
+
DTYPES = {
|
160 |
+
'float16': torch.float16,
|
161 |
+
'bfloat16': torch.bfloat16
|
162 |
+
}
|
163 |
+
|
164 |
+
llm = None
|
165 |
+
demo = None
|
166 |
+
|
167 |
+
|
168 |
+
BOS_TOKEN = '<s>'
|
169 |
+
EOS_TOKEN = '</s>'
|
170 |
+
|
171 |
+
|
172 |
+
SYSTEM_PROMPT_1 = """You are a helpful, respectful, honest and safe AI assistant built by Alibaba Group."""
|
173 |
+
|
174 |
+
|
175 |
+
|
176 |
+
# ######### RAG PREPARE
|
177 |
+
RAG_CURRENT_FILE, RAG_EMBED, RAG_CURRENT_VECTORSTORE = None, None, None
|
178 |
+
|
179 |
+
# RAG_EMBED_MODEL_NAME = "sentence-transformers/all-MiniLM-L6-v2"
|
180 |
+
RAG_EMBED_MODEL_NAME = "sentence-transformers/LaBSE"
|
181 |
+
|
182 |
+
|
183 |
+
def load_embeddings():
|
184 |
+
global RAG_EMBED
|
185 |
+
if RAG_EMBED is None:
|
186 |
+
from langchain_community.embeddings import HuggingFaceEmbeddings, HuggingFaceBgeEmbeddings
|
187 |
+
print(f'LOading embeddings: {RAG_EMBED_MODEL_NAME}')
|
188 |
+
RAG_EMBED = HuggingFaceEmbeddings(model_name=RAG_EMBED_MODEL_NAME, model_kwargs={'trust_remote_code':True, "device": "cpu"})
|
189 |
+
else:
|
190 |
+
print(f'RAG_EMBED ALREADY EXIST: {RAG_EMBED_MODEL_NAME}: {RAG_EMBED=}')
|
191 |
+
return RAG_EMBED
|
192 |
+
|
193 |
+
|
194 |
+
def get_rag_embeddings():
|
195 |
+
return load_embeddings()
|
196 |
+
|
197 |
+
_ = get_rag_embeddings()
|
198 |
+
|
199 |
+
RAG_CURRENT_VECTORSTORE = None
|
200 |
+
|
201 |
+
def load_document_split_vectorstore(file_path):
|
202 |
+
global RAG_CURRENT_FILE, RAG_EMBED, RAG_CURRENT_VECTORSTORE
|
203 |
+
from langchain.text_splitter import RecursiveCharacterTextSplitter
|
204 |
+
from langchain_community.embeddings import HuggingFaceEmbeddings, HuggingFaceBgeEmbeddings
|
205 |
+
from langchain_community.vectorstores import Chroma, FAISS
|
206 |
+
from langchain_community.document_loaders import PyPDFLoader, Docx2txtLoader, TextLoader
|
207 |
+
# assert RAG_EMBED is not None
|
208 |
+
splitter = RecursiveCharacterTextSplitter(chunk_size=1024, chunk_overlap=50)
|
209 |
+
if file_path.endswith('.pdf'):
|
210 |
+
loader = PyPDFLoader(file_path)
|
211 |
+
elif file_path.endswith('.docx'):
|
212 |
+
loader = Docx2txtLoader(file_path)
|
213 |
+
elif file_path.endswith('.txt'):
|
214 |
+
loader = TextLoader(file_path)
|
215 |
+
splits = loader.load_and_split(splitter)
|
216 |
+
RAG_CURRENT_VECTORSTORE = FAISS.from_texts(texts=[s.page_content for s in splits], embedding=get_rag_embeddings())
|
217 |
+
return RAG_CURRENT_VECTORSTORE
|
218 |
+
|
219 |
+
|
220 |
+
def docs_to_rag_context(docs: List[str]):
|
221 |
+
contexts = "\n".join([d.page_content for d in docs])
|
222 |
+
context = f"""Answer the following query exclusively based on the information provided in the document above. \
|
223 |
+
If the information is not found, please say so instead of making up facts! Remember to answer the question in the same language as the user query!
|
224 |
+
###
|
225 |
+
{contexts}
|
226 |
+
###
|
227 |
+
|
228 |
+
|
229 |
+
"""
|
230 |
+
return context
|
231 |
+
|
232 |
+
def maybe_get_doc_context(message, file_input, rag_num_docs: Optional[int] = 3):
|
233 |
+
global RAG_CURRENT_FILE, RAG_EMBED, RAG_CURRENT_VECTORSTORE
|
234 |
+
doc_context = None
|
235 |
+
if file_input is not None:
|
236 |
+
assert os.path.exists(file_input), f"not found: {file_input}"
|
237 |
+
if file_input == RAG_CURRENT_FILE:
|
238 |
+
# reuse
|
239 |
+
vectorstore = RAG_CURRENT_VECTORSTORE
|
240 |
+
print(f'Reuse vectorstore: {file_input}')
|
241 |
+
else:
|
242 |
+
vectorstore = load_document_split_vectorstore(file_input)
|
243 |
+
print(f'New vectorstore: {RAG_CURRENT_FILE} {file_input}')
|
244 |
+
RAG_CURRENT_FILE = file_input
|
245 |
+
docs = vectorstore.similarity_search(message, k=rag_num_docs)
|
246 |
+
doc_context = docs_to_rag_context(docs)
|
247 |
+
return doc_context
|
248 |
+
|
249 |
+
# ######### RAG PREPARE
|
250 |
+
|
251 |
+
|
252 |
+
# ============ CONSTANT ============
|
253 |
+
# https://github.com/gradio-app/gradio/issues/884
|
254 |
+
MODEL_NAME = "SeaLLM-7B"
|
255 |
+
MODEL_NAME = str(os.environ.get("MODEL_NAME", "SeaLLM-7B"))
|
256 |
+
|
257 |
+
MODEL_TITLE = """
|
258 |
+
<div class="container" style="
|
259 |
+
align-items: center;
|
260 |
+
justify-content: center;
|
261 |
+
display: flex;
|
262 |
+
">
|
263 |
+
<div class="image" >
|
264 |
+
<img src="file/seal_logo.png" style="
|
265 |
+
max-width: 10em;
|
266 |
+
max-height: 5%;
|
267 |
+
height: 3em;
|
268 |
+
width: 3em;
|
269 |
+
float: left;
|
270 |
+
margin-left: auto;
|
271 |
+
">
|
272 |
+
</div>
|
273 |
+
<div class="text" style="
|
274 |
+
padding-left: 20px;
|
275 |
+
padding-top: 1%;
|
276 |
+
float: left;
|
277 |
+
">
|
278 |
+
<h1 style="font-size: xx-large">SeaLLMs - Large Language Models for Southeast Asia</h1>
|
279 |
+
</div>
|
280 |
+
</div>
|
281 |
+
"""
|
282 |
+
|
283 |
+
MODEL_TITLE = """
|
284 |
+
<img src="file/seal_logo.png" style="
|
285 |
+
max-width: 10em;
|
286 |
+
max-height: 5%;
|
287 |
+
height: 3em;
|
288 |
+
width: 3em;
|
289 |
+
">
|
290 |
+
<div class="text" style="
|
291 |
+
loat: left;
|
292 |
+
padding-bottom: 2%;
|
293 |
+
">
|
294 |
+
SeaLLMs - Large Language Models for Southeast Asia
|
295 |
+
</div>
|
296 |
+
"""
|
297 |
+
|
298 |
+
"""
|
299 |
+
Somehow cannot add image here
|
300 |
+
<div class="image" >
|
301 |
+
<img src="file/seal_logo.png" style="
|
302 |
+
max-width: 10em;
|
303 |
+
max-height: 5%;
|
304 |
+
height: 3em;
|
305 |
+
width: 3em;
|
306 |
+
float: left;
|
307 |
+
margin-left: auto;
|
308 |
+
">
|
309 |
+
</div>
|
310 |
+
"""
|
311 |
+
|
312 |
+
MODEL_DESC = f"""
|
313 |
+
<div style='display:flex; gap: 0.25rem; '>
|
314 |
+
<a href='https://github.com/damo-nlp-sg/seallms'><img src='https://img.shields.io/badge/Github-Code-success'></a>
|
315 |
+
<a href='https://huggingface.co/spaces/SeaLLMs/SeaLLM-7B'><img src='https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Spaces-blue'></a>
|
316 |
+
<a href='https://huggingface.co/SeaLLMs/SeaLLM-7B-v2'><img src='https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Model-blue'></a>
|
317 |
+
<a href='https://arxiv.org/pdf/2312.00738.pdf'><img src='https://img.shields.io/badge/Paper-PDF-red'></a>
|
318 |
+
</div>
|
319 |
+
<span style="font-size: larger">
|
320 |
+
<a href="https://huggingface.co/SeaLLMs/SeaLLM-7B-v2" target="_blank">{MODEL_NAME}-v2</a> - a helpful assistant for Southeast Asian Languages 🇬🇧 🇻🇳 🇮🇩 🇹🇭 🇲🇾 🇰🇭 🇱🇦 🇵🇭 🇲🇲.
|
321 |
+
Explore <a href="https://huggingface.co/SeaLLMs/SeaLLM-7B-v2" target="_blank">our article</a> for more.
|
322 |
+
</span>
|
323 |
+
<br>
|
324 |
+
<span>
|
325 |
+
<span style="color: red">NOTE: The chatbot may produce false and harmful content and does not have up-to-date knowledge.</span>
|
326 |
+
By using our service, you are required to agree to our <a href="https://huggingface.co/SeaLLMs/SeaLLM-Chat-13b/blob/main/LICENSE" target="_blank" style="color: red">Terms Of Use</a>, which includes
|
327 |
+
not to use our service to generate any harmful, inappropriate or illegal content.
|
328 |
+
The service collects user dialogue data for testing and improvement under
|
329 |
+
<a href="https://creativecommons.org/licenses/by/4.0/">(CC-BY)</a> or similar license. So do not enter any personal information!
|
330 |
+
</span>
|
331 |
+
""".strip()
|
332 |
+
|
333 |
+
|
334 |
+
cite_markdown = """
|
335 |
+
## Citation
|
336 |
+
If you find our project useful, hope you can star our repo and cite our paper as follows:
|
337 |
+
```
|
338 |
+
@article{damonlpsg2023seallm,
|
339 |
+
author = {Xuan-Phi Nguyen*, Wenxuan Zhang*, Xin Li*, Mahani Aljunied*, Zhiqiang Hu, Chenhui Shen^, Yew Ken Chia^, Xingxuan Li, Jianyu Wang, Qingyu Tan, Liying Cheng, Guanzheng Chen, Yue Deng, Sen Yang, Chaoqun Liu, Hang Zhang, Lidong Bing},
|
340 |
+
title = {SeaLLMs - Large Language Models for Southeast Asia},
|
341 |
+
year = 2023,
|
342 |
+
}
|
343 |
+
```
|
344 |
+
"""
|
345 |
+
|
346 |
+
path_markdown = """
|
347 |
+
#### Model path:
|
348 |
+
{model_path}
|
349 |
+
"""
|
350 |
+
|
351 |
+
|
352 |
+
|
353 |
+
# ! ==================================================================
|
354 |
+
|
355 |
+
set_documentation_group("component")
|
356 |
+
|
357 |
+
|
358 |
+
RES_PRINTED = False
|
359 |
+
|
360 |
+
|
361 |
+
@document()
|
362 |
+
class ChatBot(gr.Chatbot):
|
363 |
+
def _postprocess_chat_messages(
|
364 |
+
self, chat_message
|
365 |
+
):
|
366 |
+
x = super()._postprocess_chat_messages(chat_message)
|
367 |
+
# if isinstance(x, str):
|
368 |
+
# x = x.strip().replace("\n", "<br>")
|
369 |
+
return x
|
370 |
+
|
371 |
+
|
372 |
+
from gradio.components import Button
|
373 |
+
from gradio.events import Dependency, EventListenerMethod
|
374 |
+
|
375 |
+
# replace events so that submit button is disabled during generation, if stop_btn not found
|
376 |
+
# this prevent weird behavior
|
377 |
+
def _setup_stop_events(
|
378 |
+
self, event_triggers: list[EventListenerMethod], event_to_cancel: Dependency
|
379 |
+
) -> None:
|
380 |
+
from gradio.components import State
|
381 |
+
event_triggers = event_triggers if isinstance(event_triggers, (list, tuple)) else [event_triggers]
|
382 |
+
if self.stop_btn and self.is_generator:
|
383 |
+
if self.submit_btn:
|
384 |
+
for event_trigger in event_triggers:
|
385 |
+
event_trigger(
|
386 |
+
lambda: (
|
387 |
+
Button(visible=False),
|
388 |
+
Button(visible=True),
|
389 |
+
),
|
390 |
+
None,
|
391 |
+
[self.submit_btn, self.stop_btn],
|
392 |
+
api_name=False,
|
393 |
+
queue=False,
|
394 |
+
)
|
395 |
+
event_to_cancel.then(
|
396 |
+
lambda: (Button(visible=True), Button(visible=False)),
|
397 |
+
None,
|
398 |
+
[self.submit_btn, self.stop_btn],
|
399 |
+
api_name=False,
|
400 |
+
queue=False,
|
401 |
+
)
|
402 |
+
else:
|
403 |
+
for event_trigger in event_triggers:
|
404 |
+
event_trigger(
|
405 |
+
lambda: Button(visible=True),
|
406 |
+
None,
|
407 |
+
[self.stop_btn],
|
408 |
+
api_name=False,
|
409 |
+
queue=False,
|
410 |
+
)
|
411 |
+
event_to_cancel.then(
|
412 |
+
lambda: Button(visible=False),
|
413 |
+
None,
|
414 |
+
[self.stop_btn],
|
415 |
+
api_name=False,
|
416 |
+
queue=False,
|
417 |
+
)
|
418 |
+
self.stop_btn.click(
|
419 |
+
None,
|
420 |
+
None,
|
421 |
+
None,
|
422 |
+
cancels=event_to_cancel,
|
423 |
+
api_name=False,
|
424 |
+
)
|
425 |
+
else:
|
426 |
+
if self.submit_btn:
|
427 |
+
for event_trigger in event_triggers:
|
428 |
+
event_trigger(
|
429 |
+
lambda: Button(interactive=False),
|
430 |
+
None,
|
431 |
+
[self.submit_btn],
|
432 |
+
api_name=False,
|
433 |
+
queue=False,
|
434 |
+
)
|
435 |
+
event_to_cancel.then(
|
436 |
+
lambda: Button(interactive=True),
|
437 |
+
None,
|
438 |
+
[self.submit_btn],
|
439 |
+
api_name=False,
|
440 |
+
queue=False,
|
441 |
+
)
|
442 |
+
# upon clear, cancel the submit event as well
|
443 |
+
if self.clear_btn:
|
444 |
+
self.clear_btn.click(
|
445 |
+
lambda: ([], [], None, Button(interactive=True)),
|
446 |
+
None,
|
447 |
+
[self.chatbot, self.chatbot_state, self.saved_input, self.submit_btn],
|
448 |
+
queue=False,
|
449 |
+
api_name=False,
|
450 |
+
cancels=event_to_cancel,
|
451 |
+
)
|
452 |
+
|
453 |
+
# TODO: reconfigure clear button as stop and clear button
|
454 |
+
def _setup_events(self) -> None:
|
455 |
+
from gradio.components import State
|
456 |
+
has_on = False
|
457 |
+
try:
|
458 |
+
from gradio.events import Dependency, EventListenerMethod, on
|
459 |
+
has_on = True
|
460 |
+
except ImportError as ie:
|
461 |
+
has_on = False
|
462 |
+
submit_fn = self._stream_fn if self.is_generator else self._submit_fn
|
463 |
+
|
464 |
+
def update_time(c_time, chatbot_state):
|
465 |
+
# if chatbot_state is empty, register a new conversaion with the current timestamp
|
466 |
+
# assert len(chatbot_state) > 0, f'empty chatbot state'
|
467 |
+
if len(chatbot_state) <= 1:
|
468 |
+
return gr.Number(value=time.time(), label='current_time', visible=False), chatbot_state
|
469 |
+
# elif len(chatbot_state) == 1:
|
470 |
+
# # assert chatbot_state[-1][-1] is None, f'invalid [[message, None]] , got {chatbot_state}'
|
471 |
+
# return gr.Number(value=time.time(), label='current_time', visible=False), chatbot_state
|
472 |
+
else:
|
473 |
+
return c_time, chatbot_state
|
474 |
+
|
475 |
+
if has_on:
|
476 |
+
# new version
|
477 |
+
submit_triggers = (
|
478 |
+
[self.textbox.submit, self.submit_btn.click]
|
479 |
+
if self.submit_btn
|
480 |
+
else [self.textbox.submit]
|
481 |
+
)
|
482 |
+
submit_event = (
|
483 |
+
on(
|
484 |
+
submit_triggers,
|
485 |
+
self._clear_and_save_textbox,
|
486 |
+
[self.textbox],
|
487 |
+
[self.textbox, self.saved_input],
|
488 |
+
api_name=False,
|
489 |
+
queue=False,
|
490 |
+
)
|
491 |
+
.then(
|
492 |
+
self._display_input,
|
493 |
+
[self.saved_input, self.chatbot_state],
|
494 |
+
[self.chatbot, self.chatbot_state],
|
495 |
+
api_name=False,
|
496 |
+
queue=False,
|
497 |
+
)
|
498 |
+
.then(
|
499 |
+
update_time,
|
500 |
+
[self.additional_inputs[-1], self.chatbot_state],
|
501 |
+
[self.additional_inputs[-1], self.chatbot_state],
|
502 |
+
api_name=False,
|
503 |
+
queue=False,
|
504 |
+
)
|
505 |
+
.then(
|
506 |
+
submit_fn,
|
507 |
+
[self.saved_input, self.chatbot_state] + self.additional_inputs,
|
508 |
+
[self.chatbot, self.chatbot_state],
|
509 |
+
api_name=False,
|
510 |
+
)
|
511 |
+
)
|
512 |
+
self._setup_stop_events(submit_triggers, submit_event)
|
513 |
+
else:
|
514 |
+
raise ValueError(f'Better install new gradio version than 3.44.0')
|
515 |
+
|
516 |
+
if self.retry_btn:
|
517 |
+
retry_event = (
|
518 |
+
self.retry_btn.click(
|
519 |
+
self._delete_prev_fn,
|
520 |
+
[self.chatbot_state],
|
521 |
+
[self.chatbot, self.saved_input, self.chatbot_state],
|
522 |
+
api_name=False,
|
523 |
+
queue=False,
|
524 |
+
)
|
525 |
+
.then(
|
526 |
+
self._display_input,
|
527 |
+
[self.saved_input, self.chatbot_state],
|
528 |
+
[self.chatbot, self.chatbot_state],
|
529 |
+
api_name=False,
|
530 |
+
queue=False,
|
531 |
+
)
|
532 |
+
.then(
|
533 |
+
submit_fn,
|
534 |
+
[self.saved_input, self.chatbot_state] + self.additional_inputs,
|
535 |
+
[self.chatbot, self.chatbot_state],
|
536 |
+
api_name=False,
|
537 |
+
)
|
538 |
+
)
|
539 |
+
self._setup_stop_events([self.retry_btn.click], retry_event)
|
540 |
+
|
541 |
+
if self.undo_btn:
|
542 |
+
self.undo_btn.click(
|
543 |
+
self._delete_prev_fn,
|
544 |
+
[self.chatbot_state],
|
545 |
+
[self.chatbot, self.saved_input, self.chatbot_state],
|
546 |
+
api_name=False,
|
547 |
+
queue=False,
|
548 |
+
).then(
|
549 |
+
lambda x: x,
|
550 |
+
[self.saved_input],
|
551 |
+
[self.textbox],
|
552 |
+
api_name=False,
|
553 |
+
queue=False,
|
554 |
+
)
|
555 |
+
|
556 |
+
# Reconfigure clear_btn to stop and clear text box
|
557 |
+
|
558 |
+
|
559 |
+
def _display_input(
|
560 |
+
self, message: str, history: List[List[Union[str, None]]]
|
561 |
+
) -> Tuple[List[List[Union[str, None]]], List[List[list[Union[str, None]]]]]:
|
562 |
+
if message is not None and message.strip() != "":
|
563 |
+
history.append([message, None])
|
564 |
+
return history, history
|
565 |
+
|
566 |
+
|
567 |
+
async def _stream_fn(
|
568 |
+
self,
|
569 |
+
message: str,
|
570 |
+
history_with_input,
|
571 |
+
request: Request,
|
572 |
+
*args,
|
573 |
+
) -> AsyncGenerator:
|
574 |
+
history = history_with_input[:-1]
|
575 |
+
inputs, _, _ = special_args(
|
576 |
+
self.fn, inputs=[message, history, *args], request=request
|
577 |
+
)
|
578 |
+
|
579 |
+
if self.is_async:
|
580 |
+
generator = self.fn(*inputs)
|
581 |
+
else:
|
582 |
+
generator = await anyio.to_thread.run_sync(
|
583 |
+
self.fn, *inputs, limiter=self.limiter
|
584 |
+
)
|
585 |
+
generator = SyncToAsyncIterator(generator, self.limiter)
|
586 |
+
try:
|
587 |
+
first_response = await async_iteration(generator)
|
588 |
+
update = history + [[message, first_response]]
|
589 |
+
yield update, update
|
590 |
+
except StopIteration:
|
591 |
+
update = history + [[message, None]]
|
592 |
+
yield update, update
|
593 |
+
except Exception as e:
|
594 |
+
yield history, history
|
595 |
+
raise e
|
596 |
+
|
597 |
+
try:
|
598 |
+
async for response in generator:
|
599 |
+
update = history + [[message, response]]
|
600 |
+
yield update, update
|
601 |
+
except Exception as e:
|
602 |
+
# if "invalid" in str(e):
|
603 |
+
# yield history, history
|
604 |
+
# raise e
|
605 |
+
# else:
|
606 |
+
# raise e
|
607 |
+
yield history, history
|
608 |
+
raise e
|
609 |
+
|
610 |
+
|
611 |
+
|
612 |
+
|
613 |
+
# replace
|
614 |
+
gr.ChatInterface._setup_stop_events = _setup_stop_events
|
615 |
+
gr.ChatInterface._setup_events = _setup_events
|
616 |
+
gr.ChatInterface._display_input = _display_input
|
617 |
+
gr.ChatInterface._stream_fn = _stream_fn
|
618 |
+
|
619 |
+
|
620 |
+
@document()
|
621 |
+
class CustomTabbedInterface(gr.Blocks):
|
622 |
+
def __init__(
|
623 |
+
self,
|
624 |
+
interface_list: list[gr.Interface],
|
625 |
+
tab_names: Optional[list[str]] = None,
|
626 |
+
title: Optional[str] = None,
|
627 |
+
description: Optional[str] = None,
|
628 |
+
theme: Optional[gr.Theme] = None,
|
629 |
+
analytics_enabled: Optional[bool] = None,
|
630 |
+
css: Optional[str] = None,
|
631 |
+
):
|
632 |
+
"""
|
633 |
+
Parameters:
|
634 |
+
interface_list: a list of interfaces to be rendered in tabs.
|
635 |
+
tab_names: a list of tab names. If None, the tab names will be "Tab 1", "Tab 2", etc.
|
636 |
+
title: a title for the interface; if provided, appears above the input and output components in large font. Also used as the tab title when opened in a browser window.
|
637 |
+
analytics_enabled: whether to allow basic telemetry. If None, will use GRADIO_ANALYTICS_ENABLED environment variable or default to True.
|
638 |
+
css: custom css or path to custom css file to apply to entire Blocks
|
639 |
+
Returns:
|
640 |
+
a Gradio Tabbed Interface for the given interfaces
|
641 |
+
"""
|
642 |
+
super().__init__(
|
643 |
+
title=title or "Gradio",
|
644 |
+
theme=theme,
|
645 |
+
analytics_enabled=analytics_enabled,
|
646 |
+
mode="tabbed_interface",
|
647 |
+
css=css,
|
648 |
+
)
|
649 |
+
self.description = description
|
650 |
+
if tab_names is None:
|
651 |
+
tab_names = [f"Tab {i}" for i in range(len(interface_list))]
|
652 |
+
with self:
|
653 |
+
if title:
|
654 |
+
gr.Markdown(
|
655 |
+
f"<h1 style='text-align: center; margin-bottom: 1rem'>{title}</h1>"
|
656 |
+
)
|
657 |
+
if description:
|
658 |
+
gr.Markdown(description)
|
659 |
+
with gr.Tabs():
|
660 |
+
for interface, tab_name in zip(interface_list, tab_names):
|
661 |
+
with gr.Tab(label=tab_name):
|
662 |
+
interface.render()
|
663 |
+
|
664 |
+
|
665 |
+
def vllm_abort(self):
|
666 |
+
sh = self.llm_engine.scheduler
|
667 |
+
for g in (sh.waiting + sh.running + sh.swapped):
|
668 |
+
sh.abort_seq_group(g.request_id)
|
669 |
+
from vllm.sequence import SequenceStatus
|
670 |
+
scheduler = self.llm_engine.scheduler
|
671 |
+
for state_queue in [scheduler.waiting, scheduler.running, scheduler.swapped]:
|
672 |
+
for seq_group in state_queue:
|
673 |
+
# if seq_group.request_id == request_id:
|
674 |
+
# Remove the sequence group from the state queue.
|
675 |
+
state_queue.remove(seq_group)
|
676 |
+
for seq in seq_group.seqs:
|
677 |
+
if seq.is_finished():
|
678 |
+
continue
|
679 |
+
scheduler.free_seq(seq, SequenceStatus.FINISHED_ABORTED)
|
680 |
+
|
681 |
+
|
682 |
+
def _vllm_run_engine(self: Any, use_tqdm: bool = False) -> Dict[str, Any]:
|
683 |
+
from vllm.outputs import RequestOutput
|
684 |
+
# Initialize tqdm.
|
685 |
+
if use_tqdm:
|
686 |
+
num_requests = self.llm_engine.get_num_unfinished_requests()
|
687 |
+
pbar = tqdm(total=num_requests, desc="Processed prompts")
|
688 |
+
# Run the engine.
|
689 |
+
outputs: Dict[str, RequestOutput] = {}
|
690 |
+
while self.llm_engine.has_unfinished_requests():
|
691 |
+
step_outputs = self.llm_engine.step()
|
692 |
+
for output in step_outputs:
|
693 |
+
outputs[output.request_id] = output
|
694 |
+
if len(outputs) > 0:
|
695 |
+
yield outputs
|
696 |
+
|
697 |
+
|
698 |
+
|
699 |
+
def vllm_generate_stream(
|
700 |
+
self: Any,
|
701 |
+
prompts: Optional[Union[str, List[str]]] = None,
|
702 |
+
sampling_params: Optional[Any] = None,
|
703 |
+
prompt_token_ids: Optional[List[List[int]]] = None,
|
704 |
+
use_tqdm: bool = False,
|
705 |
+
) -> Dict[str, Any]:
|
706 |
+
"""Generates the completions for the input prompts.
|
707 |
+
|
708 |
+
NOTE: This class automatically batches the given prompts, considering
|
709 |
+
the memory constraint. For the best performance, put all of your prompts
|
710 |
+
into a single list and pass it to this method.
|
711 |
+
|
712 |
+
Args:
|
713 |
+
prompts: A list of prompts to generate completions for.
|
714 |
+
sampling_params: The sampling parameters for text generation. If
|
715 |
+
None, we use the default sampling parameters.
|
716 |
+
prompt_token_ids: A list of token IDs for the prompts. If None, we
|
717 |
+
use the tokenizer to convert the prompts to token IDs.
|
718 |
+
use_tqdm: Whether to use tqdm to display the progress bar.
|
719 |
+
|
720 |
+
Returns:
|
721 |
+
A list of `RequestOutput` objects containing the generated
|
722 |
+
completions in the same order as the input prompts.
|
723 |
+
"""
|
724 |
+
from vllm import LLM, SamplingParams
|
725 |
+
if prompts is None and prompt_token_ids is None:
|
726 |
+
raise ValueError("Either prompts or prompt_token_ids must be "
|
727 |
+
"provided.")
|
728 |
+
if isinstance(prompts, str):
|
729 |
+
# Convert a single prompt to a list.
|
730 |
+
prompts = [prompts]
|
731 |
+
if prompts is not None and prompt_token_ids is not None:
|
732 |
+
if len(prompts) != len(prompt_token_ids):
|
733 |
+
raise ValueError("The lengths of prompts and prompt_token_ids "
|
734 |
+
"must be the same.")
|
735 |
+
if sampling_params is None:
|
736 |
+
# Use default sampling params.
|
737 |
+
sampling_params = SamplingParams()
|
738 |
+
|
739 |
+
# Add requests to the engine.
|
740 |
+
if prompts is not None:
|
741 |
+
num_requests = len(prompts)
|
742 |
+
else:
|
743 |
+
num_requests = len(prompt_token_ids)
|
744 |
+
for i in range(num_requests):
|
745 |
+
prompt = prompts[i] if prompts is not None else None
|
746 |
+
if prompt_token_ids is None:
|
747 |
+
token_ids = None
|
748 |
+
else:
|
749 |
+
token_ids = prompt_token_ids[i]
|
750 |
+
self._add_request(prompt, sampling_params, token_ids)
|
751 |
+
# return self._run_engine(use_tqdm)
|
752 |
+
yield from _vllm_run_engine(self, use_tqdm)
|
753 |
+
|
754 |
+
|
755 |
+
|
756 |
+
# ! avoid saying
|
757 |
+
# LANG_BLOCK_MESSAGE = """Sorry, the language you have asked is currently not supported. If you have questions in other supported languages, I'll be glad to help. \
|
758 |
+
# Please also consider clearing the chat box for a better experience."""
|
759 |
+
|
760 |
+
# KEYWORD_BLOCK_MESSAGE = "Sorry, I cannot fulfill your request. If you have any unrelated question, I'll be glad to help."
|
761 |
+
|
762 |
+
LANG_BLOCK_MESSAGE = """Unsupported language."""
|
763 |
+
|
764 |
+
KEYWORD_BLOCK_MESSAGE = "Invalid request."
|
765 |
+
|
766 |
+
|
767 |
+
def _detect_lang(text):
|
768 |
+
# Disable language that may have safety risk
|
769 |
+
from langdetect import detect as detect_lang
|
770 |
+
dlang = None
|
771 |
+
try:
|
772 |
+
dlang = detect_lang(text)
|
773 |
+
except Exception as e:
|
774 |
+
if "No features in text." in str(e):
|
775 |
+
return "en"
|
776 |
+
else:
|
777 |
+
return "zh"
|
778 |
+
return dlang
|
779 |
+
|
780 |
+
|
781 |
+
def block_lang(
|
782 |
+
message: str,
|
783 |
+
history: List[Tuple[str, str]] = None,
|
784 |
+
) -> str:
|
785 |
+
# relieve history base block
|
786 |
+
if len(BLOCK_LANGS) == 0:
|
787 |
+
return False
|
788 |
+
|
789 |
+
if LANG_BLOCK_HISTORY and history is not None and any((LANG_BLOCK_MESSAGE in x[1].strip()) for x in history):
|
790 |
+
return True
|
791 |
+
else:
|
792 |
+
_lang = _detect_lang(message)
|
793 |
+
if _lang in BLOCK_LANGS:
|
794 |
+
print(f'Detect blocked {_lang}: {message}')
|
795 |
+
return True
|
796 |
+
else:
|
797 |
+
return False
|
798 |
+
|
799 |
+
|
800 |
+
def safety_check(text, history=None, ) -> Optional[str]:
|
801 |
+
"""
|
802 |
+
Despite our effort in safety tuning and red teaming, our models may still generate harmful or illegal content.
|
803 |
+
This provides an additional security measure to enhance safety and compliance with local regulations.
|
804 |
+
"""
|
805 |
+
if len(KEYWORDS) > 0 and any(x in text.lower() for x in KEYWORDS):
|
806 |
+
return KEYWORD_BLOCK_MESSAGE
|
807 |
+
|
808 |
+
if len(BLOCK_LANGS) > 0:
|
809 |
+
if block_lang(text, history):
|
810 |
+
return LANG_BLOCK_MESSAGE
|
811 |
+
|
812 |
+
return None
|
813 |
+
|
814 |
+
|
815 |
+
|
816 |
+
TURN_TEMPLATE = "<|im_start|>{role}\n{content}</s>"
|
817 |
+
TURN_PREFIX = "<|im_start|>{role}\n"
|
818 |
+
|
819 |
+
|
820 |
+
def chatml_chat_convo_format(conversations, add_assistant_prefix: bool, default_system=SYSTEM_PROMPT_1):
|
821 |
+
if conversations[0]['role'] != 'system':
|
822 |
+
conversations = [{"role": "system", "content": default_system}] + conversations
|
823 |
+
text = ''
|
824 |
+
for turn_id, turn in enumerate(conversations):
|
825 |
+
prompt = TURN_TEMPLATE.format(role=turn['role'], content=turn['content'])
|
826 |
+
text += prompt
|
827 |
+
if add_assistant_prefix:
|
828 |
+
prompt = TURN_PREFIX.format(role='assistant')
|
829 |
+
text += prompt
|
830 |
+
return text
|
831 |
+
|
832 |
+
|
833 |
+
def chatml_format(message, history=None, system_prompt=None):
|
834 |
+
conversations = []
|
835 |
+
system_prompt = system_prompt or "You are a helpful assistant."
|
836 |
+
if history is not None and len(history) > 0:
|
837 |
+
for i, (prompt, res) in enumerate(history):
|
838 |
+
conversations.append({"role": "user", "content": prompt.strip()})
|
839 |
+
conversations.append({"role": "assistant", "content": res.strip()})
|
840 |
+
conversations.append({"role": "user", "content": message.strip()})
|
841 |
+
return chatml_chat_convo_format(conversations, True, default_system=system_prompt)
|
842 |
+
|
843 |
+
|
844 |
+
def debug_chat_response_stream_multiturn(message, history):
|
845 |
+
message_safety = safety_check(message, history=history)
|
846 |
+
if message_safety is not None:
|
847 |
+
# yield message_safety
|
848 |
+
raise gr.Error(message_safety)
|
849 |
+
|
850 |
+
message = "This is a debugging message"
|
851 |
+
for i in range(len(message)):
|
852 |
+
time.sleep(0.05)
|
853 |
+
yield message[:i]
|
854 |
+
|
855 |
+
|
856 |
+
|
857 |
+
def chat_response_stream_multiturn(
|
858 |
+
message: str,
|
859 |
+
history: List[Tuple[str, str]],
|
860 |
+
temperature: float,
|
861 |
+
max_tokens: int,
|
862 |
+
frequency_penalty: float,
|
863 |
+
presence_penalty: float,
|
864 |
+
system_prompt: Optional[str] = SYSTEM_PROMPT_1,
|
865 |
+
current_time: Optional[float] = None,
|
866 |
+
# profile: Optional[gr.OAuthProfile] = None,
|
867 |
+
) -> str:
|
868 |
+
"""
|
869 |
+
gr.Number(value=temperature, label='Temperature (higher -> more random)'),
|
870 |
+
gr.Number(value=max_tokens, label='Max generated tokens (increase if want more generation)'),
|
871 |
+
gr.Number(value=frequence_penalty, label='Frequency penalty (> 0 encourage new tokens over repeated tokens)'),
|
872 |
+
gr.Number(value=presence_penalty, label='Presence penalty (> 0 encourage new tokens, < 0 encourage existing tokens)'),
|
873 |
+
gr.Textbox(value=sys_prompt, label='System prompt', lines=8, interactive=False),
|
874 |
+
gr.Number(value=0, label='current_time', visible=False),
|
875 |
+
"""
|
876 |
+
global LOG_FILE, LOG_PATH
|
877 |
+
if DEBUG:
|
878 |
+
yield from debug_chat_response_stream_multiturn(message, history)
|
879 |
+
return
|
880 |
+
from vllm import LLM, SamplingParams
|
881 |
+
"""Build multi turn
|
882 |
+
|
883 |
+
message is incoming prompt
|
884 |
+
history don't have the current messauge
|
885 |
+
"""
|
886 |
+
global llm, RES_PRINTED
|
887 |
+
assert llm is not None
|
888 |
+
assert system_prompt.strip() != '', f'system prompt is empty'
|
889 |
+
# is_by_pass = False if profile is None else profile.username in BYPASS_USERS
|
890 |
+
is_by_pass = False
|
891 |
+
|
892 |
+
tokenizer = llm.get_tokenizer()
|
893 |
+
# force removing all
|
894 |
+
vllm_abort(llm)
|
895 |
+
|
896 |
+
temperature = float(temperature)
|
897 |
+
frequency_penalty = float(frequency_penalty)
|
898 |
+
max_tokens = int(max_tokens)
|
899 |
+
|
900 |
+
message = message.strip()
|
901 |
+
|
902 |
+
if GET_LOG_CMD != "" and message.strip() == GET_LOG_CMD:
|
903 |
+
print_log_file()
|
904 |
+
yield "Finish printed log. Please clear the chatbox now."
|
905 |
+
return
|
906 |
+
|
907 |
+
if len(message) == 0:
|
908 |
+
raise gr.Error("The message cannot be empty!")
|
909 |
+
|
910 |
+
message_safety = safety_check(message, history=history)
|
911 |
+
if message_safety is not None and not is_by_pass:
|
912 |
+
# yield message_safety
|
913 |
+
raise gr.Error(message_safety)
|
914 |
+
|
915 |
+
# history will be appended with message later on
|
916 |
+
|
917 |
+
full_prompt = chatml_format(message.strip(), history=history, system_prompt=system_prompt)
|
918 |
+
print(full_prompt)
|
919 |
+
|
920 |
+
if len(tokenizer.encode(full_prompt)) >= 4050:
|
921 |
+
raise gr.Error(f"Conversation or prompt is too long, please clear the chatbox or try shorter input.")
|
922 |
+
|
923 |
+
sampling_params = SamplingParams(
|
924 |
+
temperature=temperature,
|
925 |
+
max_tokens=max_tokens,
|
926 |
+
frequency_penalty=frequency_penalty,
|
927 |
+
presence_penalty=presence_penalty,
|
928 |
+
# stop=['<s>', '</s>', '<<SYS>>', '<</SYS>>', '[INST]', '[/INST]'],
|
929 |
+
stop=['<s>', '</s>', '<|im_start|>', '<|im_end|>'],
|
930 |
+
)
|
931 |
+
cur_out = None
|
932 |
+
|
933 |
+
for j, gen in enumerate(vllm_generate_stream(llm, full_prompt, sampling_params)):
|
934 |
+
if cur_out is not None and (STREAM_YIELD_MULTIPLE < 1 or j % STREAM_YIELD_MULTIPLE == 0) and j > 0:
|
935 |
+
# cur_out = cur_out.replace("\\n", "\n")
|
936 |
+
|
937 |
+
# optionally check safety, and respond
|
938 |
+
if STREAM_CHECK_MULTIPLE > 0 and j % STREAM_CHECK_MULTIPLE == 0:
|
939 |
+
message_safety = safety_check(cur_out, history=None)
|
940 |
+
if message_safety is not None and not is_by_pass:
|
941 |
+
# yield message_safety
|
942 |
+
raise gr.Error(message_safety)
|
943 |
+
# return
|
944 |
+
|
945 |
+
yield cur_out
|
946 |
+
assert len(gen) == 1, f'{gen}'
|
947 |
+
item = next(iter(gen.values()))
|
948 |
+
cur_out = item.outputs[0].text
|
949 |
+
#cur_out = "Our system is under maintenance, will be back soon!"
|
950 |
+
if j >= max_tokens - 2:
|
951 |
+
gr.Warning(f'The response hits limit of {max_tokens} tokens. Consider increase the max tokens parameter in the Additional Inputs.')
|
952 |
+
|
953 |
+
# TODO: use current_time to register conversations, accoriding history and cur_out
|
954 |
+
history_str = format_conversation(history + [[message, cur_out]])
|
955 |
+
print(f'@@@@@@@@@@\n{history_str}\n##########\n')
|
956 |
+
|
957 |
+
maybe_log_conv_file(current_time, history, message, cur_out, temperature=temperature, frequency_penalty=frequency_penalty)
|
958 |
+
|
959 |
+
if cur_out is not None and "\\n" in cur_out:
|
960 |
+
print(f'double slash-n in cur_out:\n{cur_out}')
|
961 |
+
cur_out = cur_out.replace("\\n", "\n")
|
962 |
+
|
963 |
+
if cur_out is not None:
|
964 |
+
yield cur_out
|
965 |
+
|
966 |
+
message_safety = safety_check(cur_out, history=None)
|
967 |
+
if message_safety is not None and not is_by_pass:
|
968 |
+
# yield message_safety
|
969 |
+
raise gr.Error(message_safety)
|
970 |
+
# return
|
971 |
+
|
972 |
+
|
973 |
+
|
974 |
+
def chat_response_stream_rag_multiturn(
|
975 |
+
message: str,
|
976 |
+
history: List[Tuple[str, str]],
|
977 |
+
file_input: str,
|
978 |
+
temperature: float,
|
979 |
+
max_tokens: int,
|
980 |
+
# frequency_penalty: float,
|
981 |
+
# presence_penalty: float,
|
982 |
+
system_prompt: Optional[str] = SYSTEM_PROMPT_1,
|
983 |
+
current_time: Optional[float] = None,
|
984 |
+
rag_num_docs: Optional[int] = 3,
|
985 |
+
):
|
986 |
+
message = message.strip()
|
987 |
+
frequency_penalty = FREQUENCE_PENALTY
|
988 |
+
presence_penalty = PRESENCE_PENALTY
|
989 |
+
if len(message) == 0:
|
990 |
+
raise gr.Error("The message cannot be empty!")
|
991 |
+
doc_context = maybe_get_doc_context(message, file_input, rag_num_docs=rag_num_docs)
|
992 |
+
if doc_context is not None:
|
993 |
+
message = f"{doc_context}\n\n{message}"
|
994 |
+
yield from chat_response_stream_multiturn(
|
995 |
+
message, history, temperature, max_tokens, frequency_penalty,
|
996 |
+
presence_penalty, system_prompt, current_time
|
997 |
+
)
|
998 |
+
|
999 |
+
|
1000 |
+
def debug_generate_free_form_stream(message):
|
1001 |
+
output = " This is a debugging message...."
|
1002 |
+
for i in range(len(output)):
|
1003 |
+
time.sleep(0.05)
|
1004 |
+
yield message + output[:i]
|
1005 |
+
|
1006 |
+
|
1007 |
+
def generate_free_form_stream(
|
1008 |
+
message: str,
|
1009 |
+
temperature: float,
|
1010 |
+
max_tokens: int,
|
1011 |
+
frequency_penalty: float,
|
1012 |
+
presence_penalty: float,
|
1013 |
+
stop_strings: str = '<s>,</s>,<|im_start|>,<|im_end|>',
|
1014 |
+
current_time: Optional[float] = None,
|
1015 |
+
) -> str:
|
1016 |
+
global LOG_FILE, LOG_PATH
|
1017 |
+
if DEBUG:
|
1018 |
+
yield from debug_generate_free_form_stream(message)
|
1019 |
+
return
|
1020 |
+
from vllm import LLM, SamplingParams
|
1021 |
+
"""Build multi turn
|
1022 |
+
"""
|
1023 |
+
global llm, RES_PRINTED
|
1024 |
+
assert llm is not None
|
1025 |
+
tokenizer = llm.get_tokenizer()
|
1026 |
+
# force removing all
|
1027 |
+
vllm_abort(llm)
|
1028 |
+
|
1029 |
+
temperature = float(temperature)
|
1030 |
+
frequency_penalty = float(frequency_penalty)
|
1031 |
+
max_tokens = int(max_tokens)
|
1032 |
+
|
1033 |
+
stop_strings = [x.strip() for x in stop_strings.strip().split(",")]
|
1034 |
+
stop_strings = list(set(stop_strings + ['</s>', '<|im_start|>']))
|
1035 |
+
|
1036 |
+
sampling_params = SamplingParams(
|
1037 |
+
temperature=temperature,
|
1038 |
+
max_tokens=max_tokens,
|
1039 |
+
frequency_penalty=frequency_penalty,
|
1040 |
+
presence_penalty=presence_penalty,
|
1041 |
+
stop=stop_strings,
|
1042 |
+
# ignore_eos=True,
|
1043 |
+
)
|
1044 |
+
|
1045 |
+
# full_prompt = message
|
1046 |
+
if len(message) == 0:
|
1047 |
+
raise gr.Error("The message cannot be empty!")
|
1048 |
+
|
1049 |
+
message_safety = safety_check(message)
|
1050 |
+
if message_safety is not None:
|
1051 |
+
raise gr.Error(message_safety)
|
1052 |
+
|
1053 |
+
if len(tokenizer.encode(message)) >= 4050:
|
1054 |
+
raise gr.Error(f"Prompt is too long!")
|
1055 |
+
|
1056 |
+
cur_out = None
|
1057 |
+
for j, gen in enumerate(vllm_generate_stream(llm, message, sampling_params)):
|
1058 |
+
if cur_out is not None and (STREAM_YIELD_MULTIPLE < 1 or j % STREAM_YIELD_MULTIPLE == 0) and j > 0:
|
1059 |
+
# optionally check safety, and respond
|
1060 |
+
if STREAM_CHECK_MULTIPLE > 0 and j % STREAM_CHECK_MULTIPLE == 0:
|
1061 |
+
message_safety = safety_check(cur_out, history=None)
|
1062 |
+
if message_safety is not None:
|
1063 |
+
raise gr.Error(message_safety)
|
1064 |
+
yield message + cur_out
|
1065 |
+
assert len(gen) == 1, f'{gen}'
|
1066 |
+
item = next(iter(gen.values()))
|
1067 |
+
cur_out = item.outputs[0].text
|
1068 |
+
#cur_out = "Our system is under maintenance, will be back soon!"
|
1069 |
+
if j >= max_tokens - 2:
|
1070 |
+
gr.Warning(f'The response hits limit of {max_tokens} tokens. Consider increase the max tokens parameter in the Additional Inputs.')
|
1071 |
+
|
1072 |
+
if cur_out is not None:
|
1073 |
+
yield message + cur_out
|
1074 |
+
|
1075 |
+
message_safety = safety_check(message + cur_out, history=None)
|
1076 |
+
if message_safety is not None:
|
1077 |
+
raise gr.Error(message_safety)
|
1078 |
+
|
1079 |
+
|
1080 |
+
|
1081 |
+
|
1082 |
+
def maybe_log_conv_file(current_time, history, message, response, **kwargs):
|
1083 |
+
global LOG_FILE
|
1084 |
+
if LOG_FILE is not None:
|
1085 |
+
my_history = history + [[message, response]]
|
1086 |
+
obj = {
|
1087 |
+
'key': str(current_time),
|
1088 |
+
'history': my_history
|
1089 |
+
}
|
1090 |
+
for k, v in kwargs.items():
|
1091 |
+
obj[k] = v
|
1092 |
+
log_ = json.dumps(obj, ensure_ascii=False)
|
1093 |
+
LOG_FILE.write(log_ + "\n")
|
1094 |
+
LOG_FILE.flush()
|
1095 |
+
print(f'Wrote {obj["key"]} to {LOG_PATH}')
|
1096 |
+
|
1097 |
+
|
1098 |
+
def format_conversation(history):
|
1099 |
+
_str = '\n'.join([
|
1100 |
+
(
|
1101 |
+
f'<<<User>>> {h[0]}\n'
|
1102 |
+
f'<<<Asst>>> {h[1]}'
|
1103 |
+
)
|
1104 |
+
for h in history
|
1105 |
+
])
|
1106 |
+
return _str
|
1107 |
+
|
1108 |
+
|
1109 |
+
def aggregate_convos():
|
1110 |
+
from datetime import datetime
|
1111 |
+
global LOG_FILE, DATA_SET_REPO_PATH, SAVE_LOGS
|
1112 |
+
assert os.path.exists(LOG_PATH), f'{LOG_PATH} not found'
|
1113 |
+
convos = None
|
1114 |
+
irregular_count = 1
|
1115 |
+
with open(LOG_PATH, 'r', encoding='utf-8') as f:
|
1116 |
+
convos = {}
|
1117 |
+
for i, l in enumerate(f):
|
1118 |
+
if l:
|
1119 |
+
item = json.loads(l)
|
1120 |
+
key = item['key']
|
1121 |
+
try:
|
1122 |
+
key = float(key)
|
1123 |
+
except Exception as e:
|
1124 |
+
key = -1
|
1125 |
+
if key > 0.0:
|
1126 |
+
item_key = datetime.fromtimestamp(key).strftime("%Y-%m-%d %H:%M:%S")
|
1127 |
+
else:
|
1128 |
+
key = item_key = f'e{irregular_count}'
|
1129 |
+
irregular_count += 1
|
1130 |
+
item['key'] = item_key
|
1131 |
+
convos[key] = item
|
1132 |
+
return convos
|
1133 |
+
|
1134 |
+
def maybe_upload_to_dataset():
|
1135 |
+
from datetime import datetime
|
1136 |
+
global LOG_FILE, DATA_SET_REPO_PATH, SAVE_LOGS
|
1137 |
+
if SAVE_LOGS and os.path.exists(LOG_PATH) and DATA_SET_REPO_PATH != "":
|
1138 |
+
convos = aggregate_convos()
|
1139 |
+
AGG_LOG_PATH = LOG_PATH + ".agg.json"
|
1140 |
+
with open(AGG_LOG_PATH, 'w', encoding='utf-8') as fo:
|
1141 |
+
json.dump(convos, fo, indent=4, ensure_ascii=False)
|
1142 |
+
print(f'Saved aggregated json to {AGG_LOG_PATH}')
|
1143 |
+
try:
|
1144 |
+
from huggingface_hub import upload_file
|
1145 |
+
print(f'upload {AGG_LOG_PATH} to {DATA_SET_REPO_PATH}')
|
1146 |
+
upload_file(
|
1147 |
+
path_or_fileobj=AGG_LOG_PATH,
|
1148 |
+
path_in_repo=os.path.basename(AGG_LOG_PATH),
|
1149 |
+
repo_id=DATA_SET_REPO_PATH,
|
1150 |
+
token=HF_TOKEN,
|
1151 |
+
repo_type="dataset",
|
1152 |
+
create_pr=True
|
1153 |
+
)
|
1154 |
+
except Exception as e:
|
1155 |
+
print(f'Failed to save to repo: {DATA_SET_REPO_PATH}|{str(e)}')
|
1156 |
+
|
1157 |
+
|
1158 |
+
def print_log_file():
|
1159 |
+
global LOG_FILE, LOG_PATH
|
1160 |
+
if SAVE_LOGS and os.path.exists(LOG_PATH):
|
1161 |
+
with open(LOG_PATH, 'r', encoding='utf-8') as f:
|
1162 |
+
convos = aggregate_convos()
|
1163 |
+
print(f'Printing log from {LOG_PATH}')
|
1164 |
+
items = list(convos.items())
|
1165 |
+
for k, v in items[-10:]:
|
1166 |
+
history = v.pop('history')
|
1167 |
+
print(f'######--{v}--#####')
|
1168 |
+
_str = format_conversation(history)
|
1169 |
+
print(_str)
|
1170 |
+
maybe_upload_to_dataset()
|
1171 |
+
|
1172 |
+
|
1173 |
+
def debug_chat_response_echo(
|
1174 |
+
message: str,
|
1175 |
+
history: List[Tuple[str, str]],
|
1176 |
+
temperature: float = 0.0,
|
1177 |
+
max_tokens: int = 4096,
|
1178 |
+
frequency_penalty: float = 0.4,
|
1179 |
+
presence_penalty: float = 0.0,
|
1180 |
+
current_time: Optional[float] = None,
|
1181 |
+
system_prompt: str = SYSTEM_PROMPT_1,
|
1182 |
+
) -> str:
|
1183 |
+
global LOG_FILE
|
1184 |
+
import time
|
1185 |
+
time.sleep(0.5)
|
1186 |
+
|
1187 |
+
if message.strip() == GET_LOG_CMD:
|
1188 |
+
print_log_file()
|
1189 |
+
yield "Finish printed log."
|
1190 |
+
return
|
1191 |
+
|
1192 |
+
for i in range(len(message)):
|
1193 |
+
yield f"repeat: {current_time} {message[:i + 1]}"
|
1194 |
+
|
1195 |
+
cur_out = f"repeat: {current_time} {message}"
|
1196 |
+
maybe_log_conv_file(current_time, history, message, cur_out, temperature=temperature, frequency_penalty=frequency_penalty)
|
1197 |
+
|
1198 |
+
|
1199 |
+
def check_model_path(model_path) -> str:
|
1200 |
+
assert os.path.exists(model_path), f'{model_path} not found'
|
1201 |
+
ckpt_info = "None"
|
1202 |
+
if os.path.isdir(model_path):
|
1203 |
+
if os.path.exists(f'{model_path}/info.txt'):
|
1204 |
+
with open(f'{model_path}/info.txt', 'r') as f:
|
1205 |
+
ckpt_info = f.read()
|
1206 |
+
print(f'Checkpoint info:\n{ckpt_info}\n-----')
|
1207 |
+
else:
|
1208 |
+
print(f'info.txt not found in {model_path}')
|
1209 |
+
print(f'model path dir: {list(os.listdir(model_path))}')
|
1210 |
+
|
1211 |
+
return ckpt_info
|
1212 |
+
|
1213 |
+
|
1214 |
+
def maybe_delete_folder():
|
1215 |
+
if IS_DELETE_FOLDER and DOWNLOAD_SNAPSHOT:
|
1216 |
+
import shutil
|
1217 |
+
print(f'DELETE ALL FILES IN {DELETE_FOLDER}')
|
1218 |
+
for filename in os.listdir(DELETE_FOLDER):
|
1219 |
+
file_path = os.path.join(DELETE_FOLDER, filename)
|
1220 |
+
try:
|
1221 |
+
if os.path.isfile(file_path) or os.path.islink(file_path):
|
1222 |
+
os.unlink(file_path)
|
1223 |
+
elif os.path.isdir(file_path):
|
1224 |
+
shutil.rmtree(file_path)
|
1225 |
+
except Exception as e:
|
1226 |
+
print('Failed to delete %s. Reason: %s' % (file_path, e))
|
1227 |
+
|
1228 |
+
|
1229 |
+
AGREE_POP_SCRIPTS = """
|
1230 |
+
async () => {
|
1231 |
+
alert("To use our service, you are required to agree to the following terms:\\nYou must not use our service to generate any harmful, unethical or illegal content that violates local and international laws, including but not limited to hate speech, violence and deception.\\nThe service may collect user dialogue data for performance improvement, and reserves the right to distribute it under CC-BY or similar license. So do not enter any personal information!");
|
1232 |
+
}
|
1233 |
+
"""
|
1234 |
+
|
1235 |
+
def debug_file_function(
|
1236 |
+
files: Union[str, List[str]],
|
1237 |
+
prompt_mode: str,
|
1238 |
+
temperature: float,
|
1239 |
+
max_tokens: int,
|
1240 |
+
frequency_penalty: float,
|
1241 |
+
presence_penalty: float,
|
1242 |
+
stop_strings: str = "[STOP],<s>,</s>",
|
1243 |
+
current_time: Optional[float] = None,
|
1244 |
+
):
|
1245 |
+
"""This is only for debug purpose"""
|
1246 |
+
files = files if isinstance(files, list) else [files]
|
1247 |
+
print(files)
|
1248 |
+
filenames = [f.name for f in files]
|
1249 |
+
all_items = []
|
1250 |
+
for fname in filenames:
|
1251 |
+
print(f'Reading {fname}')
|
1252 |
+
with open(fname, 'r', encoding='utf-8') as f:
|
1253 |
+
items = json.load(f)
|
1254 |
+
assert isinstance(items, list), f'invalid items from {fname} not list'
|
1255 |
+
all_items.extend(items)
|
1256 |
+
print(all_items)
|
1257 |
+
print(f'{prompt_mode} / {temperature} / {max_tokens}, {frequency_penalty}, {presence_penalty}')
|
1258 |
+
save_path = "./test.json"
|
1259 |
+
with open(save_path, 'w', encoding='utf-8') as f:
|
1260 |
+
json.dump(all_items, f, indent=4, ensure_ascii=False)
|
1261 |
+
|
1262 |
+
for x in all_items:
|
1263 |
+
x['response'] = "Return response"
|
1264 |
+
|
1265 |
+
print_items = all_items[:1]
|
1266 |
+
# print_json = json.dumps(print_items, indent=4, ensure_ascii=False)
|
1267 |
+
return save_path, print_items
|
1268 |
+
|
1269 |
+
|
1270 |
+
def validate_file_item(filename, index, item: Dict[str, str]):
|
1271 |
+
"""
|
1272 |
+
check safety for items in files
|
1273 |
+
"""
|
1274 |
+
message = item['prompt'].strip()
|
1275 |
+
|
1276 |
+
if len(message) == 0:
|
1277 |
+
raise gr.Error(f'Prompt {index} empty')
|
1278 |
+
|
1279 |
+
message_safety = safety_check(message, history=None)
|
1280 |
+
if message_safety is not None:
|
1281 |
+
raise gr.Error(f'Prompt {index} invalid: {message_safety}')
|
1282 |
+
|
1283 |
+
tokenizer = llm.get_tokenizer() if llm is not None else None
|
1284 |
+
if tokenizer is None or len(tokenizer.encode(message)) >= BATCH_INFER_MAX_PROMPT_TOKENS:
|
1285 |
+
raise gr.Error(f"Prompt {index} too long, should be less than {BATCH_INFER_MAX_PROMPT_TOKENS} tokens")
|
1286 |
+
|
1287 |
+
|
1288 |
+
def read_validate_json_files(files: Union[str, List[str]]):
|
1289 |
+
files = files if isinstance(files, list) else [files]
|
1290 |
+
filenames = [f.name for f in files]
|
1291 |
+
all_items = []
|
1292 |
+
for fname in filenames:
|
1293 |
+
# check each files
|
1294 |
+
print(f'Reading {fname}')
|
1295 |
+
with open(fname, 'r', encoding='utf-8') as f:
|
1296 |
+
items = json.load(f)
|
1297 |
+
assert isinstance(items, list), f'Data {fname} not list'
|
1298 |
+
assert all(isinstance(x, dict) for x in items), f'item in input file not list'
|
1299 |
+
assert all("prompt" in x for x in items), f'key prompt should be in dict item of input file'
|
1300 |
+
|
1301 |
+
for i, x in enumerate(items):
|
1302 |
+
validate_file_item(fname, i, x)
|
1303 |
+
|
1304 |
+
all_items.extend(items)
|
1305 |
+
|
1306 |
+
if len(all_items) > BATCH_INFER_MAX_ITEMS:
|
1307 |
+
raise gr.Error(f"Num samples {len(all_items)} > {BATCH_INFER_MAX_ITEMS} allowed.")
|
1308 |
+
|
1309 |
+
return all_items, filenames
|
1310 |
+
|
1311 |
+
|
1312 |
+
def remove_gradio_cache(exclude_names=None):
|
1313 |
+
"""remove gradio cache to avoid flooding"""
|
1314 |
+
import shutil
|
1315 |
+
for root, dirs, files in os.walk('/tmp/gradio/'):
|
1316 |
+
for f in files:
|
1317 |
+
# if not any(f in ef for ef in except_files):
|
1318 |
+
if exclude_names is None or not any(ef in f for ef in exclude_names):
|
1319 |
+
print(f'Remove: {f}')
|
1320 |
+
os.unlink(os.path.join(root, f))
|
1321 |
+
# for d in dirs:
|
1322 |
+
# # if not any(d in ef for ef in except_files):
|
1323 |
+
# if exclude_names is None or not any(ef in d for ef in exclude_names):
|
1324 |
+
# print(f'Remove d: {d}')
|
1325 |
+
# shutil.rmtree(os.path.join(root, d))
|
1326 |
+
|
1327 |
+
|
1328 |
+
def maybe_upload_batch_set(pred_json_path):
|
1329 |
+
global LOG_FILE, DATA_SET_REPO_PATH, SAVE_LOGS
|
1330 |
+
|
1331 |
+
if SAVE_LOGS and DATA_SET_REPO_PATH != "":
|
1332 |
+
try:
|
1333 |
+
from huggingface_hub import upload_file
|
1334 |
+
path_in_repo = "misc/" + os.path.basename(pred_json_path).replace(".json", f'.{time.time()}.json')
|
1335 |
+
print(f'upload {pred_json_path} to {DATA_SET_REPO_PATH}//{path_in_repo}')
|
1336 |
+
upload_file(
|
1337 |
+
path_or_fileobj=pred_json_path,
|
1338 |
+
path_in_repo=path_in_repo,
|
1339 |
+
repo_id=DATA_SET_REPO_PATH,
|
1340 |
+
token=HF_TOKEN,
|
1341 |
+
repo_type="dataset",
|
1342 |
+
create_pr=True
|
1343 |
+
)
|
1344 |
+
except Exception as e:
|
1345 |
+
print(f'Failed to save to repo: {DATA_SET_REPO_PATH}|{str(e)}')
|
1346 |
+
|
1347 |
+
|
1348 |
+
def free_form_prompt(prompt, history=None, system_prompt=None):
|
1349 |
+
return prompt
|
1350 |
+
|
1351 |
+
def batch_inference(
|
1352 |
+
files: Union[str, List[str]],
|
1353 |
+
prompt_mode: str,
|
1354 |
+
temperature: float,
|
1355 |
+
max_tokens: int,
|
1356 |
+
frequency_penalty: float,
|
1357 |
+
presence_penalty: float,
|
1358 |
+
stop_strings: str = "[STOP],<s>,</s>,<|im_start|>",
|
1359 |
+
current_time: Optional[float] = None,
|
1360 |
+
system_prompt: Optional[str] = SYSTEM_PROMPT_1
|
1361 |
+
):
|
1362 |
+
"""
|
1363 |
+
Handle file upload batch inference
|
1364 |
+
|
1365 |
+
"""
|
1366 |
+
global LOG_FILE, LOG_PATH, DEBUG, llm, RES_PRINTED
|
1367 |
+
if DEBUG:
|
1368 |
+
return debug_file_function(
|
1369 |
+
files, prompt_mode, temperature, max_tokens,
|
1370 |
+
presence_penalty, stop_strings, current_time)
|
1371 |
+
|
1372 |
+
from vllm import LLM, SamplingParams
|
1373 |
+
assert llm is not None
|
1374 |
+
# assert system_prompt.strip() != '', f'system prompt is empty'
|
1375 |
+
|
1376 |
+
stop_strings = [x.strip() for x in stop_strings.strip().split(",")]
|
1377 |
+
tokenizer = llm.get_tokenizer()
|
1378 |
+
# force removing all
|
1379 |
+
# NOTE: need to make sure all cached items are removed!!!!!!!!!
|
1380 |
+
vllm_abort(llm)
|
1381 |
+
|
1382 |
+
temperature = float(temperature)
|
1383 |
+
frequency_penalty = float(frequency_penalty)
|
1384 |
+
max_tokens = int(max_tokens)
|
1385 |
+
|
1386 |
+
all_items, filenames = read_validate_json_files(files)
|
1387 |
+
|
1388 |
+
# remove all items in /tmp/gradio/
|
1389 |
+
remove_gradio_cache(exclude_names=['upload_chat.json', 'upload_few_shot.json'])
|
1390 |
+
|
1391 |
+
if prompt_mode == 'chat':
|
1392 |
+
prompt_format_fn = chatml_format
|
1393 |
+
elif prompt_mode == 'few-shot':
|
1394 |
+
from functools import partial
|
1395 |
+
# prompt_format_fn = partial(
|
1396 |
+
# chatml_format, include_end_instruct=False
|
1397 |
+
# )
|
1398 |
+
prompt_format_fn = free_form_prompt
|
1399 |
+
else:
|
1400 |
+
raise gr.Error(f'Wrong mode {prompt_mode}')
|
1401 |
+
|
1402 |
+
full_prompts = [
|
1403 |
+
prompt_format_fn(
|
1404 |
+
x['prompt'], [], sys_prompt=system_prompt
|
1405 |
+
)
|
1406 |
+
for i, x in enumerate(all_items)
|
1407 |
+
]
|
1408 |
+
print(f'{full_prompts[0]}\n')
|
1409 |
+
|
1410 |
+
if any(len(tokenizer.encode(x)) >= 4090 for x in full_prompts):
|
1411 |
+
raise gr.Error(f"Some prompt is too long!")
|
1412 |
+
|
1413 |
+
stop_seq = list(set(['<s>', '</s>', '<<SYS>>', '<</SYS>>', '[INST]', '[/INST]'] + stop_strings))
|
1414 |
+
sampling_params = SamplingParams(
|
1415 |
+
temperature=temperature,
|
1416 |
+
max_tokens=max_tokens,
|
1417 |
+
frequency_penalty=frequency_penalty,
|
1418 |
+
presence_penalty=presence_penalty,
|
1419 |
+
stop=stop_seq
|
1420 |
+
)
|
1421 |
+
|
1422 |
+
generated = llm.generate(full_prompts, sampling_params, use_tqdm=False)
|
1423 |
+
responses = [g.outputs[0].text for g in generated]
|
1424 |
+
#responses = ["Our system is under maintenance, will be back soon!" for g in generated]
|
1425 |
+
if len(responses) != len(all_items):
|
1426 |
+
raise gr.Error(f'inconsistent lengths {len(responses)} != {len(all_items)}')
|
1427 |
+
|
1428 |
+
for res, item in zip(responses, all_items):
|
1429 |
+
item['response'] = res
|
1430 |
+
|
1431 |
+
save_path = BATCH_INFER_SAVE_TMP_FILE
|
1432 |
+
os.makedirs(os.path.dirname(save_path), exist_ok=True)
|
1433 |
+
with open(save_path, 'w', encoding='utf-8') as f:
|
1434 |
+
json.dump(all_items, f, indent=4, ensure_ascii=False)
|
1435 |
+
|
1436 |
+
# You need to upload save_path as a new timestamp file.
|
1437 |
+
maybe_upload_batch_set(save_path)
|
1438 |
+
|
1439 |
+
print_items = all_items[:2]
|
1440 |
+
# print_json = json.dumps(print_items, indent=4, ensure_ascii=False)
|
1441 |
+
return save_path, print_items
|
1442 |
+
|
1443 |
+
|
1444 |
+
# BATCH_INFER_MAX_ITEMS
|
1445 |
+
FILE_UPLOAD_DESCRIPTION = f"""Upload JSON file as list of dict with < {BATCH_INFER_MAX_ITEMS} items, \
|
1446 |
+
each item has `prompt` key. We put guardrails to enhance safety, so do not input any harmful content or personal information! Re-upload the file after every submit. See the examples below.
|
1447 |
+
```
|
1448 |
+
[ {{"id": 0, "prompt": "Hello world"}} , {{"id": 1, "prompt": "Hi there?"}}]
|
1449 |
+
```
|
1450 |
+
"""
|
1451 |
+
|
1452 |
+
CHAT_EXAMPLES = [
|
1453 |
+
["Hãy giải thích thuyết tương đối rộng."],
|
1454 |
+
["Tolong bantu saya menulis email ke lembaga pemerintah untuk mencari dukungan finansial untuk penelitian AI."],
|
1455 |
+
["แนะนำ 10 จุดหมายปลายทางในกรุงเทพฯ"],
|
1456 |
+
]
|
1457 |
+
|
1458 |
+
|
1459 |
+
# performance items
|
1460 |
+
|
1461 |
+
def create_free_form_generation_demo():
|
1462 |
+
global short_model_path
|
1463 |
+
max_tokens = MAX_TOKENS
|
1464 |
+
temperature = TEMPERATURE
|
1465 |
+
frequence_penalty = FREQUENCE_PENALTY
|
1466 |
+
presence_penalty = PRESENCE_PENALTY
|
1467 |
+
|
1468 |
+
introduction = """
|
1469 |
+
### Free-form | Put any context string (like few-shot prompts)
|
1470 |
+
"""
|
1471 |
+
|
1472 |
+
with gr.Blocks() as demo_free_form:
|
1473 |
+
gr.Markdown(introduction)
|
1474 |
+
|
1475 |
+
with gr.Row():
|
1476 |
+
txt = gr.Textbox(
|
1477 |
+
scale=4,
|
1478 |
+
lines=16,
|
1479 |
+
show_label=False,
|
1480 |
+
placeholder="Enter any free form text and submit",
|
1481 |
+
container=False,
|
1482 |
+
)
|
1483 |
+
with gr.Row():
|
1484 |
+
free_submit_button = gr.Button('Submit')
|
1485 |
+
with gr.Row():
|
1486 |
+
temp = gr.Number(value=temperature, label='Temperature', info="Higher -> more random")
|
1487 |
+
length = gr.Number(value=max_tokens, label='Max tokens', info='Increase if want more generation')
|
1488 |
+
freq_pen = gr.Number(value=frequence_penalty, label='Frequency penalty', info='> 0 encourage new tokens over repeated tokens')
|
1489 |
+
pres_pen = gr.Number(value=presence_penalty, label='Presence penalty', info='> 0 encourage new tokens, < 0 encourage existing tokens')
|
1490 |
+
stop_strings = gr.Textbox(value="<s>,</s>,<|im_start|>", label='Stop strings', info='Comma-separated string to stop generation only in FEW-SHOT mode', lines=1)
|
1491 |
+
|
1492 |
+
free_submit_button.click(
|
1493 |
+
generate_free_form_stream,
|
1494 |
+
[txt, temp, length, freq_pen, pres_pen, stop_strings],
|
1495 |
+
txt
|
1496 |
+
)
|
1497 |
+
return demo_free_form
|
1498 |
+
|
1499 |
+
|
1500 |
+
|
1501 |
+
def create_file_upload_demo():
|
1502 |
+
temperature = TEMPERATURE
|
1503 |
+
frequence_penalty = FREQUENCE_PENALTY
|
1504 |
+
presence_penalty = PRESENCE_PENALTY
|
1505 |
+
max_tokens = MAX_TOKENS
|
1506 |
+
demo_file_upload = gr.Interface(
|
1507 |
+
batch_inference,
|
1508 |
+
inputs=[
|
1509 |
+
gr.File(file_count='single', file_types=['json']),
|
1510 |
+
gr.Radio(["chat", "few-shot"], value='chat', label="Chat or Few-shot mode", info="Chat's output more user-friendly, Few-shot's output more consistent with few-shot patterns."),
|
1511 |
+
gr.Number(value=temperature, label='Temperature', info="Higher -> more random"),
|
1512 |
+
gr.Number(value=max_tokens, label='Max tokens', info='Increase if want more generation'),
|
1513 |
+
gr.Number(value=frequence_penalty, label='Frequency penalty', info='> 0 encourage new tokens over repeated tokens'),
|
1514 |
+
gr.Number(value=presence_penalty, label='Presence penalty', info='> 0 encourage new tokens, < 0 encourage existing tokens'),
|
1515 |
+
gr.Textbox(value="<s>,</s>,<|im_start|>", label='Stop strings', info='Comma-separated string to stop generation only in FEW-SHOT mode', lines=1),
|
1516 |
+
gr.Number(value=0, label='current_time', visible=False),
|
1517 |
+
],
|
1518 |
+
outputs=[
|
1519 |
+
# "file",
|
1520 |
+
gr.File(label="Generated file"),
|
1521 |
+
# "json"
|
1522 |
+
gr.JSON(label='Example outputs (display 2 samples)')
|
1523 |
+
],
|
1524 |
+
description=FILE_UPLOAD_DESCRIPTION,
|
1525 |
+
allow_flagging=False,
|
1526 |
+
examples=[
|
1527 |
+
["upload_chat.json", "chat", 0.2, 1024, 0.5, 0, "<s>,</s>,<|im_start|>"],
|
1528 |
+
["upload_few_shot.json", "few-shot", 0.2, 128, 0.5, 0, "<s>,</s>,<|im_start|>,\\n"]
|
1529 |
+
],
|
1530 |
+
cache_examples=False,
|
1531 |
+
)
|
1532 |
+
return demo_file_upload
|
1533 |
+
|
1534 |
+
|
1535 |
+
def create_chat_demo(title=None, description=None):
|
1536 |
+
sys_prompt = SYSTEM_PROMPT_1
|
1537 |
+
max_tokens = MAX_TOKENS
|
1538 |
+
temperature = TEMPERATURE
|
1539 |
+
frequence_penalty = FREQUENCE_PENALTY
|
1540 |
+
presence_penalty = PRESENCE_PENALTY
|
1541 |
+
|
1542 |
+
demo_chat = gr.ChatInterface(
|
1543 |
+
chat_response_stream_multiturn,
|
1544 |
+
chatbot=ChatBot(
|
1545 |
+
label=MODEL_NAME,
|
1546 |
+
bubble_full_width=False,
|
1547 |
+
latex_delimiters=[
|
1548 |
+
{ "left": "$", "right": "$", "display": False},
|
1549 |
+
{ "left": "$$", "right": "$$", "display": True},
|
1550 |
+
],
|
1551 |
+
show_copy_button=True,
|
1552 |
+
),
|
1553 |
+
textbox=gr.Textbox(placeholder='Type message', lines=1, max_lines=128, min_width=200),
|
1554 |
+
submit_btn=gr.Button(value='Submit', variant="primary", scale=0),
|
1555 |
+
# ! consider preventing the stop button
|
1556 |
+
# stop_btn=None,
|
1557 |
+
title=title,
|
1558 |
+
description=description,
|
1559 |
+
additional_inputs=[
|
1560 |
+
gr.Number(value=temperature, label='Temperature (higher -> more random)'),
|
1561 |
+
gr.Number(value=max_tokens, label='Max generated tokens (increase if want more generation)'),
|
1562 |
+
gr.Number(value=frequence_penalty, label='Frequency penalty (> 0 encourage new tokens over repeated tokens)'),
|
1563 |
+
gr.Number(value=presence_penalty, label='Presence penalty (> 0 encourage new tokens, < 0 encourage existing tokens)'),
|
1564 |
+
gr.Textbox(value=sys_prompt, label='System prompt', lines=4, interactive=False),
|
1565 |
+
gr.Number(value=0, label='current_time', visible=False),
|
1566 |
+
# ! Remove the system prompt textbox to avoid jailbreaking
|
1567 |
+
],
|
1568 |
+
examples=CHAT_EXAMPLES,
|
1569 |
+
cache_examples=False
|
1570 |
+
)
|
1571 |
+
return demo_chat
|
1572 |
+
|
1573 |
+
|
1574 |
+
def upload_file(file):
|
1575 |
+
# file_paths = [file.name for file in files]
|
1576 |
+
# return file_paths
|
1577 |
+
return file.name
|
1578 |
+
|
1579 |
+
|
1580 |
+
RAG_DESCRIPTION = """
|
1581 |
+
* Upload a doc below to answer question about it (RAG).
|
1582 |
+
* Every question must be explicit and self-contained! Because each prompt will invoke a new RAG retrieval without considering previous conversations.
|
1583 |
+
(E.g: Dont prompt "Answer my previous question in details.")
|
1584 |
+
"""
|
1585 |
+
|
1586 |
+
def create_chat_demo_rag(title=None, description=None):
|
1587 |
+
sys_prompt = SYSTEM_PROMPT_1
|
1588 |
+
max_tokens = MAX_TOKENS
|
1589 |
+
temperature = TEMPERATURE
|
1590 |
+
frequence_penalty = FREQUENCE_PENALTY
|
1591 |
+
presence_penalty = PRESENCE_PENALTY
|
1592 |
+
description = description or RAG_DESCRIPTION
|
1593 |
+
|
1594 |
+
# with gr.Blocks(title="RAG") as rag_demo:
|
1595 |
+
additional_inputs = [
|
1596 |
+
gr.File(label='Upload Document', file_count='single', file_types=['pdf', 'docx', 'txt', 'json']),
|
1597 |
+
# gr.Textbox(value=None, label='Document path', lines=1, interactive=False),
|
1598 |
+
gr.Number(value=temperature, label='Temperature (higher -> more random)'),
|
1599 |
+
gr.Number(value=max_tokens, label='Max generated tokens (increase if want more generation)'),
|
1600 |
+
# gr.Number(value=frequence_penalty, label='Frequency penalty (> 0 encourage new tokens over repeated tokens)'),
|
1601 |
+
# gr.Number(value=presence_penalty, label='Presence penalty (> 0 encourage new tokens, < 0 encourage existing tokens)'),
|
1602 |
+
gr.Textbox(value=sys_prompt, label='System prompt', lines=1, interactive=False),
|
1603 |
+
gr.Number(value=0, label='current_time', visible=False),
|
1604 |
+
]
|
1605 |
+
|
1606 |
+
demo_rag_chat = gr.ChatInterface(
|
1607 |
+
chat_response_stream_rag_multiturn,
|
1608 |
+
chatbot=gr.Chatbot(
|
1609 |
+
label=MODEL_NAME + "-RAG",
|
1610 |
+
bubble_full_width=False,
|
1611 |
+
latex_delimiters=[
|
1612 |
+
{ "left": "$", "right": "$", "display": False},
|
1613 |
+
{ "left": "$$", "right": "$$", "display": True},
|
1614 |
+
],
|
1615 |
+
show_copy_button=True,
|
1616 |
+
),
|
1617 |
+
textbox=gr.Textbox(placeholder='Type message', lines=1, max_lines=128, min_width=200),
|
1618 |
+
submit_btn=gr.Button(value='Submit', variant="primary", scale=0),
|
1619 |
+
# ! consider preventing the stop button
|
1620 |
+
# stop_btn=None,
|
1621 |
+
title=title,
|
1622 |
+
description=description,
|
1623 |
+
additional_inputs=additional_inputs,
|
1624 |
+
additional_inputs_accordion=gr.Accordion("Additional Inputs", open=True),
|
1625 |
+
# examples=CHAT_EXAMPLES,
|
1626 |
+
cache_examples=False
|
1627 |
+
)
|
1628 |
+
# with demo_rag_chat:
|
1629 |
+
# upload_button = gr.UploadButton("Click to Upload document", file_types=['pdf', 'docx', 'txt', 'json'], file_count="single")
|
1630 |
+
# upload_button.upload(upload_file, upload_button, additional_inputs[0])
|
1631 |
+
|
1632 |
+
# return demo_chat
|
1633 |
+
return demo_rag_chat
|
1634 |
+
|
1635 |
+
|
1636 |
+
|
1637 |
+
def launch_demo():
|
1638 |
+
global demo, llm, DEBUG, LOG_FILE
|
1639 |
+
model_desc = MODEL_DESC
|
1640 |
+
model_path = MODEL_PATH
|
1641 |
+
model_title = MODEL_TITLE
|
1642 |
+
hf_model_name = HF_MODEL_NAME
|
1643 |
+
tensor_parallel = TENSOR_PARALLEL
|
1644 |
+
assert tensor_parallel > 0 , f'{tensor_parallel} invalid'
|
1645 |
+
dtype = DTYPE
|
1646 |
+
sys_prompt = SYSTEM_PROMPT_1
|
1647 |
+
max_tokens = MAX_TOKENS
|
1648 |
+
temperature = TEMPERATURE
|
1649 |
+
frequence_penalty = FREQUENCE_PENALTY
|
1650 |
+
presence_penalty = PRESENCE_PENALTY
|
1651 |
+
ckpt_info = "None"
|
1652 |
+
|
1653 |
+
print(
|
1654 |
+
f'Launch config: '
|
1655 |
+
f'\n| model_title=`{model_title}` '
|
1656 |
+
f'\n| max_tokens={max_tokens} '
|
1657 |
+
f'\n| dtype={dtype} '
|
1658 |
+
f'\n| tensor_parallel={tensor_parallel} '
|
1659 |
+
f'\n| IS_DELETE_FOLDER={IS_DELETE_FOLDER} '
|
1660 |
+
f'\n| STREAM_YIELD_MULTIPLE={STREAM_YIELD_MULTIPLE} '
|
1661 |
+
f'\n| STREAM_CHECK_MULTIPLE={STREAM_CHECK_MULTIPLE} '
|
1662 |
+
f'\n| DISPLAY_MODEL_PATH={DISPLAY_MODEL_PATH} '
|
1663 |
+
f'\n| LANG_BLOCK_HISTORY={LANG_BLOCK_HISTORY} '
|
1664 |
+
f'\n| frequence_penalty={frequence_penalty} '
|
1665 |
+
f'\n| presence_penalty={presence_penalty} '
|
1666 |
+
f'\n| temperature={temperature} '
|
1667 |
+
# f'\n| hf_model_name={hf_model_name} '
|
1668 |
+
f'\n| model_path={model_path} '
|
1669 |
+
f'\n| DOWNLOAD_SNAPSHOT={DOWNLOAD_SNAPSHOT} '
|
1670 |
+
f'\n| gpu_memory_utilization={gpu_memory_utilization} '
|
1671 |
+
f'\n| LOG_PATH={LOG_PATH} | SAVE_LOGS={SAVE_LOGS} '
|
1672 |
+
f'\n| Desc={model_desc}'
|
1673 |
+
)
|
1674 |
+
|
1675 |
+
if DEBUG:
|
1676 |
+
model_desc += "\n<br>!!!!! This is in debug mode, responses will copy original"
|
1677 |
+
# response_fn = debug_chat_response_echo
|
1678 |
+
response_fn = chat_response_stream_multiturn
|
1679 |
+
print(f'Creating in DEBUG MODE')
|
1680 |
+
if SAVE_LOGS:
|
1681 |
+
LOG_FILE = open(LOG_PATH, 'a', encoding='utf-8')
|
1682 |
+
else:
|
1683 |
+
# ! load the model
|
1684 |
+
maybe_delete_folder()
|
1685 |
+
|
1686 |
+
if DOWNLOAD_SNAPSHOT:
|
1687 |
+
print(f'Downloading from HF_MODEL_NAME={hf_model_name} -> {model_path}')
|
1688 |
+
if HF_TOKEN is not None:
|
1689 |
+
print(f'Load with HF_TOKEN: {HF_TOKEN}')
|
1690 |
+
snapshot_download(hf_model_name, local_dir=model_path, use_auth_token=True, token=HF_TOKEN)
|
1691 |
+
else:
|
1692 |
+
snapshot_download(hf_model_name, local_dir=model_path)
|
1693 |
+
|
1694 |
+
import vllm
|
1695 |
+
from vllm import LLM
|
1696 |
+
|
1697 |
+
print(F'VLLM: {vllm.__version__}')
|
1698 |
+
ckpt_info = check_model_path(model_path)
|
1699 |
+
|
1700 |
+
print(f'Load path: {model_path} | {ckpt_info}')
|
1701 |
+
|
1702 |
+
if QUANTIZATION == 'awq':
|
1703 |
+
print(F'Load model in int4 quantization')
|
1704 |
+
llm = LLM(model=model_path, dtype="float16", tensor_parallel_size=tensor_parallel, gpu_memory_utilization=gpu_memory_utilization, quantization="awq", max_model_len=8192)
|
1705 |
+
else:
|
1706 |
+
llm = LLM(model=model_path, dtype=dtype, tensor_parallel_size=tensor_parallel, gpu_memory_utilization=gpu_memory_utilization, max_model_len=8192)
|
1707 |
+
|
1708 |
+
try:
|
1709 |
+
print(llm.llm_engine.workers[0].model)
|
1710 |
+
except Exception as e:
|
1711 |
+
print(f'Cannot print model worker: {e}')
|
1712 |
+
|
1713 |
+
try:
|
1714 |
+
llm.llm_engine.scheduler_config.max_model_len = 8192
|
1715 |
+
llm.llm_engine.scheduler_config.max_num_batched_tokens = 8192
|
1716 |
+
# llm.llm_engine.tokenizer.add_special_tokens = False
|
1717 |
+
except Exception as e:
|
1718 |
+
print(f'Cannot set parameters: {e}')
|
1719 |
+
|
1720 |
+
print(f'Use system prompt:\n{sys_prompt}')
|
1721 |
+
|
1722 |
+
response_fn = chat_response_stream_multiturn
|
1723 |
+
print(F'respond: {response_fn}')
|
1724 |
+
|
1725 |
+
if SAVE_LOGS:
|
1726 |
+
LOG_FILE = open(LOG_PATH, 'a', encoding='utf-8')
|
1727 |
+
|
1728 |
+
if ENABLE_BATCH_INFER:
|
1729 |
+
|
1730 |
+
# demo_file_upload = create_file_upload_demo()
|
1731 |
+
|
1732 |
+
demo_free_form = create_free_form_generation_demo()
|
1733 |
+
|
1734 |
+
demo_chat = create_chat_demo()
|
1735 |
+
demo_chat_rag = create_chat_demo_rag(description=RAG_DESCRIPTION)
|
1736 |
+
descriptions = model_desc
|
1737 |
+
if DISPLAY_MODEL_PATH:
|
1738 |
+
descriptions += f"<br> {path_markdown.format(model_path=model_path)}"
|
1739 |
+
|
1740 |
+
demo = CustomTabbedInterface(
|
1741 |
+
interface_list=[
|
1742 |
+
demo_chat,
|
1743 |
+
demo_chat_rag,
|
1744 |
+
demo_free_form,
|
1745 |
+
# demo_file_upload,
|
1746 |
+
],
|
1747 |
+
tab_names=[
|
1748 |
+
"Chat Interface",
|
1749 |
+
"RAG Chat Interface",
|
1750 |
+
"Text completion",
|
1751 |
+
# "Batch Inference",
|
1752 |
+
],
|
1753 |
+
title=f"{model_title}",
|
1754 |
+
description=descriptions,
|
1755 |
+
)
|
1756 |
+
else:
|
1757 |
+
descriptions = model_desc
|
1758 |
+
if DISPLAY_MODEL_PATH:
|
1759 |
+
descriptions += f"<br> {path_markdown.format(model_path=model_path)}"
|
1760 |
+
|
1761 |
+
demo = create_chat_demo(title=f"{model_title}", description=descriptions)
|
1762 |
+
demo.title = MODEL_NAME
|
1763 |
+
|
1764 |
+
with demo:
|
1765 |
+
if DATA_SET_REPO_PATH != "":
|
1766 |
+
try:
|
1767 |
+
from performance_plot import attach_plot_to_demo
|
1768 |
+
attach_plot_to_demo(demo)
|
1769 |
+
except Exception as e:
|
1770 |
+
print(f'Fail to load DEMO plot: {str(e)}')
|
1771 |
+
|
1772 |
+
gr.Markdown(cite_markdown)
|
1773 |
+
if DISPLAY_MODEL_PATH:
|
1774 |
+
gr.Markdown(path_markdown.format(model_path=model_path))
|
1775 |
+
|
1776 |
+
if ENABLE_AGREE_POPUP:
|
1777 |
+
demo.load(None, None, None, _js=AGREE_POP_SCRIPTS)
|
1778 |
+
|
1779 |
+
# login_btn = gr.LoginButton()
|
1780 |
+
|
1781 |
+
demo.queue(api_open=False)
|
1782 |
+
return demo
|
1783 |
+
|
1784 |
+
|
1785 |
+
if __name__ == "__main__":
|
1786 |
+
demo = launch_demo()
|
1787 |
+
demo.launch(show_api=False, allowed_paths=["seal_logo.png"])
|
seammm_2.png
ADDED
Git LFS Details
|
transformers_requirements.txt
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
transformers
|
vllm_requirements.txt
ADDED
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
1 |
+
transformers
|
2 |
+
vllm
|