Upload folder using huggingface_hub
Browse files- ChatApp/app.py +253 -0
- ChatApp/app_modules/__pycache__/overwrites.cpython-39.pyc +0 -0
- ChatApp/app_modules/__pycache__/presets.cpython-39.pyc +0 -0
- ChatApp/app_modules/__pycache__/utils.cpython-39.pyc +0 -0
- ChatApp/app_modules/overwrites.py +33 -0
- ChatApp/app_modules/presets.py +81 -0
- ChatApp/app_modules/utils.py +235 -0
- ChatApp/assets/custom.css +488 -0
- ChatApp/assets/custom.js +1 -0
- ChatApp/interface/__pycache__/base_interface.cpython-39.pyc +0 -0
- ChatApp/interface/__pycache__/empty_stub_interface.cpython-39.pyc +0 -0
- ChatApp/interface/__pycache__/hddr_llama_onnx_interface.cpython-39.pyc +0 -0
- ChatApp/interface/base_interface.py +6 -0
- ChatApp/interface/empty_stub_interface.py +39 -0
- ChatApp/interface/hddr_llama_onnx_interface.py +395 -0
- ChatApp/requirements.txt +18 -0
ChatApp/app.py
ADDED
@@ -0,0 +1,253 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# -*- coding:utf-8 -*-
|
2 |
+
import os
|
3 |
+
import logging
|
4 |
+
import gradio as gr
|
5 |
+
import gc
|
6 |
+
from interface.hddr_llama_onnx_interface import LlamaOnnxInterface
|
7 |
+
from interface.empty_stub_interface import EmptyStubInterface
|
8 |
+
from ChatApp.app_modules.utils import (
|
9 |
+
reset_textbox,
|
10 |
+
transfer_input,
|
11 |
+
reset_state,
|
12 |
+
delete_last_conversation,
|
13 |
+
cancel_outputing,
|
14 |
+
)
|
15 |
+
from ChatApp.app_modules.presets import (
|
16 |
+
small_and_beautiful_theme,
|
17 |
+
title,
|
18 |
+
description_top,
|
19 |
+
description,
|
20 |
+
)
|
21 |
+
from ChatApp.app_modules.overwrites import postprocess
|
22 |
+
|
23 |
+
logging.basicConfig(
|
24 |
+
level=logging.DEBUG,
|
25 |
+
format="%(asctime)s [%(levelname)s] [%(filename)s:%(lineno)d] %(message)s",
|
26 |
+
)
|
27 |
+
|
28 |
+
# we can filter this dictionary at the start according to the actual available files on disk
|
29 |
+
empty_stub_model_name = "_Empty Stub_"
|
30 |
+
|
31 |
+
top_directory = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
|
32 |
+
|
33 |
+
tokenizer_path = os.path.join(top_directory, "tokenizer.model")
|
34 |
+
|
35 |
+
available_models = {
|
36 |
+
"Llama-2 13B Float16": {
|
37 |
+
"onnx_file": os.path.join(
|
38 |
+
top_directory, "FP16", "LlamaV2_13B_float16.onnx"
|
39 |
+
),
|
40 |
+
"tokenizer_path": tokenizer_path,
|
41 |
+
"embedding_file": os.path.join(top_directory, "embeddings.pth"),
|
42 |
+
},
|
43 |
+
"Llama-2 13B FP32": {
|
44 |
+
"onnx_file": os.path.join(
|
45 |
+
top_directory, "FP32", "LlamaV2_13B_float16.onnx"
|
46 |
+
),
|
47 |
+
"tokenizer_path": tokenizer_path,
|
48 |
+
"embedding_file": os.path.join(
|
49 |
+
top_directory, "embeddings.pth"
|
50 |
+
),
|
51 |
+
},
|
52 |
+
}
|
53 |
+
|
54 |
+
|
55 |
+
interface = EmptyStubInterface()
|
56 |
+
interface.initialize()
|
57 |
+
|
58 |
+
# interface = None
|
59 |
+
|
60 |
+
gr.Chatbot.postprocess = postprocess
|
61 |
+
|
62 |
+
with open("ChatApp/assets/custom.css", "r", encoding="utf-8") as f:
|
63 |
+
custom_css = f.read()
|
64 |
+
|
65 |
+
|
66 |
+
def change_model_listener(new_model_name):
|
67 |
+
if new_model_name is None:
|
68 |
+
new_model_name = empty_stub_model_name
|
69 |
+
|
70 |
+
global interface
|
71 |
+
|
72 |
+
# if a model exists - shut it down before trying to create the new one
|
73 |
+
if interface is not None:
|
74 |
+
interface.shutdown()
|
75 |
+
del interface
|
76 |
+
gc.collect()
|
77 |
+
|
78 |
+
logging.info(f"Creating a new model [{new_model_name}]")
|
79 |
+
if new_model_name == empty_stub_model_name:
|
80 |
+
interface = EmptyStubInterface()
|
81 |
+
interface.initialize()
|
82 |
+
else:
|
83 |
+
d = available_models[new_model_name]
|
84 |
+
interface = LlamaOnnxInterface(
|
85 |
+
onnx_file=d["onnx_file"],
|
86 |
+
tokenizer_path=d["tokenizer_path"],
|
87 |
+
embedding_file=d["embedding_file"],
|
88 |
+
)
|
89 |
+
interface.initialize()
|
90 |
+
|
91 |
+
return new_model_name
|
92 |
+
|
93 |
+
|
94 |
+
def interface_predict(*args):
|
95 |
+
global interface
|
96 |
+
res = interface.predict(*args)
|
97 |
+
|
98 |
+
for x in res:
|
99 |
+
yield x
|
100 |
+
|
101 |
+
|
102 |
+
def interface_retry(*args):
|
103 |
+
global interface
|
104 |
+
res = interface.retry(*args)
|
105 |
+
|
106 |
+
for x in res:
|
107 |
+
yield x
|
108 |
+
|
109 |
+
|
110 |
+
with gr.Blocks(css=custom_css, theme=small_and_beautiful_theme) as demo:
|
111 |
+
history = gr.State([])
|
112 |
+
user_question = gr.State("")
|
113 |
+
with gr.Row():
|
114 |
+
gr.HTML(title)
|
115 |
+
status_display = gr.Markdown("Success", elem_id="status_display")
|
116 |
+
gr.Markdown(description_top)
|
117 |
+
|
118 |
+
with gr.Row():
|
119 |
+
with gr.Column(scale=5):
|
120 |
+
with gr.Row():
|
121 |
+
chatbot = gr.Chatbot(elem_id="chuanhu_chatbot", height=900)
|
122 |
+
with gr.Row():
|
123 |
+
with gr.Column(scale=12):
|
124 |
+
user_input = gr.Textbox(show_label=False, placeholder="Enter text")
|
125 |
+
with gr.Column(min_width=70, scale=1):
|
126 |
+
submit_button = gr.Button("Send")
|
127 |
+
with gr.Column(min_width=70, scale=1):
|
128 |
+
cancel_button = gr.Button("Stop")
|
129 |
+
with gr.Row():
|
130 |
+
empty_button = gr.Button(
|
131 |
+
"🧹 New Conversation",
|
132 |
+
)
|
133 |
+
retry_button = gr.Button("🔄 Regenerate")
|
134 |
+
delete_last_button = gr.Button("🗑️ Remove Last Turn")
|
135 |
+
with gr.Column():
|
136 |
+
with gr.Column(min_width=50, scale=1):
|
137 |
+
with gr.Tab(label="Parameter Setting"):
|
138 |
+
gr.Markdown("# Model")
|
139 |
+
model_name = gr.Dropdown(
|
140 |
+
choices=[empty_stub_model_name] + list(available_models.keys()),
|
141 |
+
label="Model",
|
142 |
+
show_label=False, # default="Empty STUB",
|
143 |
+
)
|
144 |
+
model_name.change(
|
145 |
+
change_model_listener, inputs=[model_name], outputs=[model_name]
|
146 |
+
)
|
147 |
+
|
148 |
+
gr.Markdown("# Parameters")
|
149 |
+
top_p = gr.Slider(
|
150 |
+
minimum=-0,
|
151 |
+
maximum=1.0,
|
152 |
+
value=0.9,
|
153 |
+
step=0.05,
|
154 |
+
interactive=True,
|
155 |
+
label="Top-p",
|
156 |
+
)
|
157 |
+
temperature = gr.Slider(
|
158 |
+
minimum=0.1,
|
159 |
+
maximum=2.0,
|
160 |
+
value=0.75,
|
161 |
+
step=0.1,
|
162 |
+
interactive=True,
|
163 |
+
label="Temperature",
|
164 |
+
)
|
165 |
+
max_length_tokens = gr.Slider(
|
166 |
+
minimum=0,
|
167 |
+
maximum=512,
|
168 |
+
value=256,
|
169 |
+
step=8,
|
170 |
+
interactive=True,
|
171 |
+
label="Max Generation Tokens",
|
172 |
+
)
|
173 |
+
max_context_length_tokens = gr.Slider(
|
174 |
+
minimum=0,
|
175 |
+
maximum=4096,
|
176 |
+
value=2048,
|
177 |
+
step=128,
|
178 |
+
interactive=True,
|
179 |
+
label="Max History Tokens",
|
180 |
+
)
|
181 |
+
gr.Markdown(description)
|
182 |
+
|
183 |
+
predict_args = dict(
|
184 |
+
# fn=interface.predict,
|
185 |
+
fn=interface_predict,
|
186 |
+
inputs=[
|
187 |
+
user_question,
|
188 |
+
chatbot,
|
189 |
+
history,
|
190 |
+
top_p,
|
191 |
+
temperature,
|
192 |
+
max_length_tokens,
|
193 |
+
max_context_length_tokens,
|
194 |
+
],
|
195 |
+
outputs=[chatbot, history, status_display],
|
196 |
+
show_progress=True,
|
197 |
+
)
|
198 |
+
retry_args = dict(
|
199 |
+
fn=interface_retry,
|
200 |
+
inputs=[
|
201 |
+
user_input,
|
202 |
+
chatbot,
|
203 |
+
history,
|
204 |
+
top_p,
|
205 |
+
temperature,
|
206 |
+
max_length_tokens,
|
207 |
+
max_context_length_tokens,
|
208 |
+
],
|
209 |
+
outputs=[chatbot, history, status_display],
|
210 |
+
show_progress=True,
|
211 |
+
)
|
212 |
+
|
213 |
+
reset_args = dict(fn=reset_textbox, inputs=[], outputs=[user_input, status_display])
|
214 |
+
|
215 |
+
# Chatbot
|
216 |
+
transfer_input_args = dict(
|
217 |
+
fn=transfer_input,
|
218 |
+
inputs=[user_input],
|
219 |
+
outputs=[user_question, user_input, submit_button],
|
220 |
+
show_progress=True,
|
221 |
+
)
|
222 |
+
|
223 |
+
predict_event1 = user_input.submit(**transfer_input_args).then(**predict_args)
|
224 |
+
|
225 |
+
predict_event2 = submit_button.click(**transfer_input_args).then(**predict_args)
|
226 |
+
|
227 |
+
empty_button.click(
|
228 |
+
reset_state,
|
229 |
+
outputs=[chatbot, history, status_display],
|
230 |
+
show_progress=True,
|
231 |
+
)
|
232 |
+
empty_button.click(**reset_args)
|
233 |
+
|
234 |
+
predict_event3 = retry_button.click(**retry_args)
|
235 |
+
|
236 |
+
delete_last_button.click(
|
237 |
+
delete_last_conversation,
|
238 |
+
[chatbot, history],
|
239 |
+
[chatbot, history, status_display],
|
240 |
+
show_progress=True,
|
241 |
+
)
|
242 |
+
cancel_button.click(
|
243 |
+
cancel_outputing,
|
244 |
+
[],
|
245 |
+
[status_display],
|
246 |
+
cancels=[predict_event1, predict_event2, predict_event3],
|
247 |
+
)
|
248 |
+
|
249 |
+
demo.load(change_model_listener, inputs=None, outputs=model_name)
|
250 |
+
|
251 |
+
demo.title = "Llama-2 Chat UI"
|
252 |
+
|
253 |
+
demo.queue(concurrency_count=1).launch()
|
ChatApp/app_modules/__pycache__/overwrites.cpython-39.pyc
ADDED
Binary file (1.15 kB). View file
|
|
ChatApp/app_modules/__pycache__/presets.cpython-39.pyc
ADDED
Binary file (1.92 kB). View file
|
|
ChatApp/app_modules/__pycache__/utils.cpython-39.pyc
ADDED
Binary file (6.25 kB). View file
|
|
ChatApp/app_modules/overwrites.py
ADDED
@@ -0,0 +1,33 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from __future__ import annotations
|
2 |
+
from typing import List, Tuple
|
3 |
+
|
4 |
+
from app_modules.presets import gr
|
5 |
+
from app_modules.utils import detect_converted_mark, convert_asis, convert_mdtext
|
6 |
+
|
7 |
+
|
8 |
+
def postprocess(
|
9 |
+
self, y: List[Tuple[str | None, str | None]]
|
10 |
+
) -> List[Tuple[str | None, str | None]]:
|
11 |
+
"""
|
12 |
+
Parameters:
|
13 |
+
y: List of tuples representing the message and response pairs.
|
14 |
+
Each message and response should be a string,
|
15 |
+
which may be in Markdown format.
|
16 |
+
Returns:
|
17 |
+
List of tuples representing the message and response.
|
18 |
+
Each message and response will be a string of HTML.
|
19 |
+
"""
|
20 |
+
if y is None or y == []:
|
21 |
+
return []
|
22 |
+
temp = []
|
23 |
+
for x in y:
|
24 |
+
user, bot = x
|
25 |
+
if not detect_converted_mark(user):
|
26 |
+
user = convert_asis(user)
|
27 |
+
if not detect_converted_mark(bot):
|
28 |
+
bot = convert_mdtext(bot)
|
29 |
+
temp.append((user, bot))
|
30 |
+
return temp
|
31 |
+
|
32 |
+
|
33 |
+
GradioTemplateResponseOriginal = gr.routes.templates.TemplateResponse
|
ChatApp/app_modules/presets.py
ADDED
@@ -0,0 +1,81 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# -*- coding:utf-8 -*-
|
2 |
+
import gradio as gr
|
3 |
+
|
4 |
+
|
5 |
+
title = """<h1 align="left" style="min-width:200px; margin-top:0;">Llama-2 Chat UI</h1>"""
|
6 |
+
description_top = """\
|
7 |
+
<div align="left">
|
8 |
+
Use at your own risk...
|
9 |
+
</p >
|
10 |
+
</div>
|
11 |
+
"""
|
12 |
+
description = """\
|
13 |
+
<div align="center" style="margin:16px 0">
|
14 |
+
This is a chat demo using the ONNX versions of the Llama 2 model
|
15 |
+
</div>
|
16 |
+
"""
|
17 |
+
CONCURRENT_COUNT = 100
|
18 |
+
|
19 |
+
|
20 |
+
ALREADY_CONVERTED_MARK = "<!-- ALREADY CONVERTED BY PARSER. -->"
|
21 |
+
|
22 |
+
small_and_beautiful_theme = gr.themes.Soft(
|
23 |
+
primary_hue=gr.themes.Color(
|
24 |
+
c50="#02C160",
|
25 |
+
c100="rgba(2, 193, 96, 0.2)",
|
26 |
+
c200="#02C160",
|
27 |
+
c300="rgba(2, 193, 96, 0.32)",
|
28 |
+
c400="rgba(2, 193, 96, 0.32)",
|
29 |
+
c500="rgba(2, 193, 96, 1.0)",
|
30 |
+
c600="rgba(2, 193, 96, 1.0)",
|
31 |
+
c700="rgba(2, 193, 96, 0.32)",
|
32 |
+
c800="rgba(2, 193, 96, 0.32)",
|
33 |
+
c900="#02C160",
|
34 |
+
c950="#02C160",
|
35 |
+
),
|
36 |
+
secondary_hue=gr.themes.Color(
|
37 |
+
c50="#576b95",
|
38 |
+
c100="#576b95",
|
39 |
+
c200="#576b95",
|
40 |
+
c300="#576b95",
|
41 |
+
c400="#576b95",
|
42 |
+
c500="#576b95",
|
43 |
+
c600="#576b95",
|
44 |
+
c700="#576b95",
|
45 |
+
c800="#576b95",
|
46 |
+
c900="#576b95",
|
47 |
+
c950="#576b95",
|
48 |
+
),
|
49 |
+
neutral_hue=gr.themes.Color(
|
50 |
+
name="gray",
|
51 |
+
c50="#f9fafb",
|
52 |
+
c100="#f3f4f6",
|
53 |
+
c200="#e5e7eb",
|
54 |
+
c300="#d1d5db",
|
55 |
+
c400="#B2B2B2",
|
56 |
+
c500="#808080",
|
57 |
+
c600="#636363",
|
58 |
+
c700="#515151",
|
59 |
+
c800="#393939",
|
60 |
+
c900="#272727",
|
61 |
+
c950="#171717",
|
62 |
+
),
|
63 |
+
radius_size=gr.themes.sizes.radius_sm,
|
64 |
+
).set(
|
65 |
+
button_primary_background_fill="#06AE56",
|
66 |
+
button_primary_background_fill_dark="#06AE56",
|
67 |
+
button_primary_background_fill_hover="#07C863",
|
68 |
+
button_primary_border_color="#06AE56",
|
69 |
+
button_primary_border_color_dark="#06AE56",
|
70 |
+
button_primary_text_color="#FFFFFF",
|
71 |
+
button_primary_text_color_dark="#FFFFFF",
|
72 |
+
button_secondary_background_fill="#F2F2F2",
|
73 |
+
button_secondary_background_fill_dark="#2B2B2B",
|
74 |
+
button_secondary_text_color="#393939",
|
75 |
+
button_secondary_text_color_dark="#FFFFFF",
|
76 |
+
background_fill_primary="#F7F7F7",
|
77 |
+
background_fill_primary_dark="#1F1F1F",
|
78 |
+
block_title_text_color="*primary_500",
|
79 |
+
block_title_background_fill="*primary_100",
|
80 |
+
input_background_fill="#F6F6F6",
|
81 |
+
)
|
ChatApp/app_modules/utils.py
ADDED
@@ -0,0 +1,235 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# -*- coding:utf-8 -*-
|
2 |
+
from __future__ import annotations
|
3 |
+
import logging
|
4 |
+
import re
|
5 |
+
import html
|
6 |
+
|
7 |
+
import gradio as gr
|
8 |
+
import mdtex2html
|
9 |
+
from markdown import markdown
|
10 |
+
from pygments import highlight
|
11 |
+
from pygments.lexers import guess_lexer, get_lexer_by_name, ClassNotFound
|
12 |
+
from pygments.formatters import HtmlFormatter
|
13 |
+
|
14 |
+
from ChatApp.app_modules.presets import ALREADY_CONVERTED_MARK
|
15 |
+
|
16 |
+
logging.basicConfig(
|
17 |
+
level=logging.INFO,
|
18 |
+
format="%(asctime)s [%(levelname)s] [%(filename)s:%(lineno)d] %(message)s",
|
19 |
+
)
|
20 |
+
|
21 |
+
|
22 |
+
def markdown_to_html_with_syntax_highlight(md_str):
|
23 |
+
def replacer(match):
|
24 |
+
lang = match.group(1) or "text"
|
25 |
+
code = match.group(2)
|
26 |
+
lang = lang.strip()
|
27 |
+
# print(1,lang)
|
28 |
+
if lang == "text":
|
29 |
+
lexer = guess_lexer(code)
|
30 |
+
lang = lexer.name
|
31 |
+
# print(2,lang)
|
32 |
+
try:
|
33 |
+
lexer = get_lexer_by_name(lang, stripall=True)
|
34 |
+
except ValueError:
|
35 |
+
lexer = get_lexer_by_name("python", stripall=True)
|
36 |
+
formatter = HtmlFormatter()
|
37 |
+
# print(3,lexer.name)
|
38 |
+
highlighted_code = highlight(code, lexer, formatter)
|
39 |
+
|
40 |
+
return f'<pre><code class="{lang}">{highlighted_code}</code></pre>'
|
41 |
+
|
42 |
+
code_block_pattern = r"```(\w+)?\n([\s\S]+?)\n```"
|
43 |
+
md_str = re.sub(code_block_pattern, replacer, md_str, flags=re.MULTILINE)
|
44 |
+
|
45 |
+
html_str = markdown(md_str)
|
46 |
+
return html_str
|
47 |
+
|
48 |
+
|
49 |
+
def normalize_markdown(md_text: str) -> str:
|
50 |
+
lines = md_text.split("\n")
|
51 |
+
normalized_lines = []
|
52 |
+
inside_list = False
|
53 |
+
|
54 |
+
for i, line in enumerate(lines):
|
55 |
+
if re.match(r"^(\d+\.|-|\*|\+)\s", line.strip()):
|
56 |
+
if not inside_list and i > 0 and lines[i - 1].strip() != "":
|
57 |
+
normalized_lines.append("")
|
58 |
+
inside_list = True
|
59 |
+
normalized_lines.append(line)
|
60 |
+
elif inside_list and line.strip() == "":
|
61 |
+
if i < len(lines) - 1 and not re.match(
|
62 |
+
r"^(\d+\.|-|\*|\+)\s", lines[i + 1].strip()
|
63 |
+
):
|
64 |
+
normalized_lines.append(line)
|
65 |
+
continue
|
66 |
+
else:
|
67 |
+
inside_list = False
|
68 |
+
normalized_lines.append(line)
|
69 |
+
|
70 |
+
return "\n".join(normalized_lines)
|
71 |
+
|
72 |
+
|
73 |
+
def convert_mdtext(md_text):
|
74 |
+
code_block_pattern = re.compile(r"```(.*?)(?:```|$)", re.DOTALL)
|
75 |
+
inline_code_pattern = re.compile(r"`(.*?)`", re.DOTALL)
|
76 |
+
code_blocks = code_block_pattern.findall(md_text)
|
77 |
+
non_code_parts = code_block_pattern.split(md_text)[::2]
|
78 |
+
|
79 |
+
result = []
|
80 |
+
for non_code, code in zip(non_code_parts, code_blocks + [""]):
|
81 |
+
if non_code.strip():
|
82 |
+
non_code = normalize_markdown(non_code)
|
83 |
+
if inline_code_pattern.search(non_code):
|
84 |
+
result.append(markdown(non_code, extensions=["tables"]))
|
85 |
+
else:
|
86 |
+
result.append(mdtex2html.convert(non_code, extensions=["tables"]))
|
87 |
+
if code.strip():
|
88 |
+
code = f"\n```{code}\n\n```"
|
89 |
+
code = markdown_to_html_with_syntax_highlight(code)
|
90 |
+
result.append(code)
|
91 |
+
result = "".join(result)
|
92 |
+
result += ALREADY_CONVERTED_MARK
|
93 |
+
return result
|
94 |
+
|
95 |
+
|
96 |
+
def convert_asis(userinput):
|
97 |
+
return (
|
98 |
+
f'<p style="white-space:pre-wrap;">{html.escape(userinput)}</p>'
|
99 |
+
+ ALREADY_CONVERTED_MARK
|
100 |
+
)
|
101 |
+
|
102 |
+
|
103 |
+
def detect_converted_mark(userinput):
|
104 |
+
if userinput.endswith(ALREADY_CONVERTED_MARK):
|
105 |
+
return True
|
106 |
+
else:
|
107 |
+
return False
|
108 |
+
|
109 |
+
|
110 |
+
def detect_language(code):
|
111 |
+
if code.startswith("\n"):
|
112 |
+
first_line = ""
|
113 |
+
else:
|
114 |
+
first_line = code.strip().split("\n", 1)[0]
|
115 |
+
language = first_line.lower() if first_line else ""
|
116 |
+
code_without_language = code[len(first_line) :].lstrip() if first_line else code
|
117 |
+
return language, code_without_language
|
118 |
+
|
119 |
+
|
120 |
+
def convert_to_markdown(text):
|
121 |
+
text = text.replace("$", "$")
|
122 |
+
|
123 |
+
def replace_leading_tabs_and_spaces(line):
|
124 |
+
new_line = []
|
125 |
+
|
126 |
+
for char in line:
|
127 |
+
if char == "\t":
|
128 |
+
new_line.append("	")
|
129 |
+
elif char == " ":
|
130 |
+
new_line.append(" ")
|
131 |
+
else:
|
132 |
+
break
|
133 |
+
return "".join(new_line) + line[len(new_line) :]
|
134 |
+
|
135 |
+
markdown_text = ""
|
136 |
+
lines = text.split("\n")
|
137 |
+
in_code_block = False
|
138 |
+
|
139 |
+
for line in lines:
|
140 |
+
if in_code_block is False and line.startswith("```"):
|
141 |
+
in_code_block = True
|
142 |
+
markdown_text += f"{line}\n"
|
143 |
+
elif in_code_block is True and line.startswith("```"):
|
144 |
+
in_code_block = False
|
145 |
+
markdown_text += f"{line}\n"
|
146 |
+
elif in_code_block:
|
147 |
+
markdown_text += f"{line}\n"
|
148 |
+
else:
|
149 |
+
line = replace_leading_tabs_and_spaces(line)
|
150 |
+
line = re.sub(r"^(#)", r"\\\1", line)
|
151 |
+
markdown_text += f"{line} \n"
|
152 |
+
|
153 |
+
return markdown_text
|
154 |
+
|
155 |
+
|
156 |
+
def add_language_tag(text):
|
157 |
+
def detect_language(code_block):
|
158 |
+
try:
|
159 |
+
lexer = guess_lexer(code_block)
|
160 |
+
return lexer.name.lower()
|
161 |
+
except ClassNotFound:
|
162 |
+
return ""
|
163 |
+
|
164 |
+
code_block_pattern = re.compile(r"(```)(\w*\n[^`]+```)", re.MULTILINE)
|
165 |
+
|
166 |
+
def replacement(match):
|
167 |
+
code_block = match.group(2)
|
168 |
+
if match.group(2).startswith("\n"):
|
169 |
+
language = detect_language(code_block)
|
170 |
+
if language:
|
171 |
+
return f"```{language}{code_block}```"
|
172 |
+
else:
|
173 |
+
return f"```\n{code_block}```"
|
174 |
+
else:
|
175 |
+
return match.group(1) + code_block + "```"
|
176 |
+
|
177 |
+
text2 = code_block_pattern.sub(replacement, text)
|
178 |
+
return text2
|
179 |
+
|
180 |
+
|
181 |
+
def delete_last_conversation(chatbot, history):
|
182 |
+
if len(chatbot) > 0:
|
183 |
+
chatbot.pop()
|
184 |
+
|
185 |
+
if len(history) > 0:
|
186 |
+
history.pop()
|
187 |
+
|
188 |
+
return (
|
189 |
+
chatbot,
|
190 |
+
history,
|
191 |
+
"Delete Done",
|
192 |
+
)
|
193 |
+
|
194 |
+
|
195 |
+
def reset_state():
|
196 |
+
return [], [], "Reset Done"
|
197 |
+
|
198 |
+
|
199 |
+
def reset_textbox():
|
200 |
+
return gr.update(value=""), ""
|
201 |
+
|
202 |
+
|
203 |
+
def cancel_outputing():
|
204 |
+
return "Stop Done"
|
205 |
+
|
206 |
+
|
207 |
+
def transfer_input(inputs):
|
208 |
+
return (
|
209 |
+
inputs,
|
210 |
+
gr.update(value=""),
|
211 |
+
gr.Button.update(visible=True),
|
212 |
+
)
|
213 |
+
|
214 |
+
|
215 |
+
class State:
|
216 |
+
interrupted = False
|
217 |
+
|
218 |
+
def interrupt(self):
|
219 |
+
self.interrupted = True
|
220 |
+
|
221 |
+
def recover(self):
|
222 |
+
self.interrupted = False
|
223 |
+
|
224 |
+
|
225 |
+
shared_state = State()
|
226 |
+
|
227 |
+
|
228 |
+
def is_stop_word_or_prefix(s: str, stop_words: list) -> bool:
|
229 |
+
for stop_word in stop_words:
|
230 |
+
if s.endswith(stop_word):
|
231 |
+
return True
|
232 |
+
for i in range(1, len(stop_word)):
|
233 |
+
if s.endswith(stop_word[:i]):
|
234 |
+
return True
|
235 |
+
return False
|
ChatApp/assets/custom.css
ADDED
@@ -0,0 +1,488 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
:root {
|
2 |
+
--chatbot-color-light: #F3F3F3;
|
3 |
+
--chatbot-color-dark: #121111;
|
4 |
+
}
|
5 |
+
|
6 |
+
/* status_display */
|
7 |
+
#status_display {
|
8 |
+
display: flex;
|
9 |
+
min-height: 2.5em;
|
10 |
+
align-items: flex-end;
|
11 |
+
justify-content: flex-end;
|
12 |
+
}
|
13 |
+
|
14 |
+
#status_display p {
|
15 |
+
font-size: .85em;
|
16 |
+
font-family: monospace;
|
17 |
+
color: var(--body-text-color-subdued);
|
18 |
+
}
|
19 |
+
|
20 |
+
|
21 |
+
|
22 |
+
/* usage_display */
|
23 |
+
#usage_display {
|
24 |
+
height: 1em;
|
25 |
+
}
|
26 |
+
|
27 |
+
#usage_display p {
|
28 |
+
padding: 0 1em;
|
29 |
+
font-size: .85em;
|
30 |
+
font-family: monospace;
|
31 |
+
color: var(--body-text-color-subdued);
|
32 |
+
}
|
33 |
+
|
34 |
+
/* list */
|
35 |
+
ol:not(.options),
|
36 |
+
ul:not(.options) {
|
37 |
+
padding-inline-start: 2em !important;
|
38 |
+
}
|
39 |
+
|
40 |
+
/* Thank @Keldos-Li for fixing it */
|
41 |
+
/* Light mode (default) */
|
42 |
+
#chuanhu_chatbot {
|
43 |
+
background-color: var(--chatbot-color-light) !important;
|
44 |
+
color: #000000 !important;
|
45 |
+
}
|
46 |
+
|
47 |
+
[data-testid="bot"] {
|
48 |
+
background-color: #FFFFFF !important;
|
49 |
+
}
|
50 |
+
|
51 |
+
[data-testid="user"] {
|
52 |
+
background-color: #95EC69 !important;
|
53 |
+
}
|
54 |
+
|
55 |
+
/* Dark mode */
|
56 |
+
.dark #chuanhu_chatbot {
|
57 |
+
background-color: var(--chatbot-color-dark) !important;
|
58 |
+
color: #FFFFFF !important;
|
59 |
+
}
|
60 |
+
|
61 |
+
.dark [data-testid="bot"] {
|
62 |
+
background-color: #2C2C2C !important;
|
63 |
+
}
|
64 |
+
|
65 |
+
.dark [data-testid="user"] {
|
66 |
+
background-color: #26B561 !important;
|
67 |
+
}
|
68 |
+
|
69 |
+
#chuanhu_chatbot {
|
70 |
+
height: 100%;
|
71 |
+
min-height: 400px;
|
72 |
+
}
|
73 |
+
|
74 |
+
[class *="message"] {
|
75 |
+
border-radius: var(--radius-xl) !important;
|
76 |
+
border: none;
|
77 |
+
padding: var(--spacing-xl) !important;
|
78 |
+
font-size: var(--text-md) !important;
|
79 |
+
line-height: var(--line-md) !important;
|
80 |
+
min-height: calc(var(--text-md)*var(--line-md) + 2*var(--spacing-xl));
|
81 |
+
min-width: calc(var(--text-md)*var(--line-md) + 2*var(--spacing-xl));
|
82 |
+
}
|
83 |
+
|
84 |
+
[data-testid="bot"] {
|
85 |
+
max-width: 85%;
|
86 |
+
border-bottom-left-radius: 0 !important;
|
87 |
+
}
|
88 |
+
|
89 |
+
[data-testid="user"] {
|
90 |
+
max-width: 85%;
|
91 |
+
width: auto !important;
|
92 |
+
border-bottom-right-radius: 0 !important;
|
93 |
+
}
|
94 |
+
|
95 |
+
/* Table */
|
96 |
+
table {
|
97 |
+
margin: 1em 0;
|
98 |
+
border-collapse: collapse;
|
99 |
+
empty-cells: show;
|
100 |
+
}
|
101 |
+
|
102 |
+
td,
|
103 |
+
th {
|
104 |
+
border: 1.2px solid var(--border-color-primary) !important;
|
105 |
+
padding: 0.2em;
|
106 |
+
}
|
107 |
+
|
108 |
+
thead {
|
109 |
+
background-color: rgba(175, 184, 193, 0.2);
|
110 |
+
}
|
111 |
+
|
112 |
+
thead th {
|
113 |
+
padding: .5em .2em;
|
114 |
+
}
|
115 |
+
|
116 |
+
/* Inline code */
|
117 |
+
#chuanhu_chatbot code {
|
118 |
+
display: inline;
|
119 |
+
white-space: break-spaces;
|
120 |
+
border-radius: 6px;
|
121 |
+
margin: 0 2px 0 2px;
|
122 |
+
padding: .2em .4em .1em .4em;
|
123 |
+
background-color: rgba(175, 184, 193, 0.2);
|
124 |
+
}
|
125 |
+
|
126 |
+
/* Code block */
|
127 |
+
#chuanhu_chatbot pre code {
|
128 |
+
display: block;
|
129 |
+
overflow: auto;
|
130 |
+
white-space: pre;
|
131 |
+
background-color: hsla(0, 0%, 0%, 80%) !important;
|
132 |
+
border-radius: 10px;
|
133 |
+
padding: 1.4em 1.2em 0em 1.4em;
|
134 |
+
margin: 1.2em 2em 1.2em 0.5em;
|
135 |
+
color: #FFFF;
|
136 |
+
box-shadow: 6px 6px 16px hsla(0, 0%, 0%, 0.2);
|
137 |
+
}
|
138 |
+
|
139 |
+
/* Hightlight */
|
140 |
+
#chuanhu_chatbot .highlight {
|
141 |
+
background-color: transparent
|
142 |
+
}
|
143 |
+
|
144 |
+
#chuanhu_chatbot .highlight .hll {
|
145 |
+
background-color: #49483e
|
146 |
+
}
|
147 |
+
|
148 |
+
#chuanhu_chatbot .highlight .c {
|
149 |
+
color: #75715e
|
150 |
+
}
|
151 |
+
|
152 |
+
/* Comment */
|
153 |
+
#chuanhu_chatbot .highlight .err {
|
154 |
+
color: #960050;
|
155 |
+
background-color: #1e0010
|
156 |
+
}
|
157 |
+
|
158 |
+
/* Error */
|
159 |
+
#chuanhu_chatbot .highlight .k {
|
160 |
+
color: #66d9ef
|
161 |
+
}
|
162 |
+
|
163 |
+
/* Keyword */
|
164 |
+
#chuanhu_chatbot .highlight .l {
|
165 |
+
color: #ae81ff
|
166 |
+
}
|
167 |
+
|
168 |
+
/* Literal */
|
169 |
+
#chuanhu_chatbot .highlight .n {
|
170 |
+
color: #8828f2
|
171 |
+
}
|
172 |
+
|
173 |
+
/* Name */
|
174 |
+
#chuanhu_chatbot .highlight .o {
|
175 |
+
color: #f92672
|
176 |
+
}
|
177 |
+
|
178 |
+
/* Operator */
|
179 |
+
#chuanhu_chatbot .highlight .p {
|
180 |
+
color: #482822
|
181 |
+
}
|
182 |
+
|
183 |
+
/* Punctuation */
|
184 |
+
#chuanhu_chatbot .highlight .ch {
|
185 |
+
color: #75715e
|
186 |
+
}
|
187 |
+
|
188 |
+
/* Comment.Hashbang */
|
189 |
+
#chuanhu_chatbot .highlight .cm {
|
190 |
+
color: #75715e
|
191 |
+
}
|
192 |
+
|
193 |
+
/* Comment.Multiline */
|
194 |
+
#chuanhu_chatbot .highlight .cp {
|
195 |
+
color: #75715e
|
196 |
+
}
|
197 |
+
|
198 |
+
/* Comment.Preproc */
|
199 |
+
#chuanhu_chatbot .highlight .cpf {
|
200 |
+
color: #75715e
|
201 |
+
}
|
202 |
+
|
203 |
+
/* Comment.PreprocFile */
|
204 |
+
#chuanhu_chatbot .highlight .c1 {
|
205 |
+
color: #75715e
|
206 |
+
}
|
207 |
+
|
208 |
+
/* Comment.Single */
|
209 |
+
#chuanhu_chatbot .highlight .cs {
|
210 |
+
color: #75715e
|
211 |
+
}
|
212 |
+
|
213 |
+
/* Comment.Special */
|
214 |
+
#chuanhu_chatbot .highlight .gd {
|
215 |
+
color: #f92672
|
216 |
+
}
|
217 |
+
|
218 |
+
/* Generic.Deleted */
|
219 |
+
#chuanhu_chatbot .highlight .ge {
|
220 |
+
font-style: italic
|
221 |
+
}
|
222 |
+
|
223 |
+
/* Generic.Emph */
|
224 |
+
#chuanhu_chatbot .highlight .gi {
|
225 |
+
color: #a6e22e
|
226 |
+
}
|
227 |
+
|
228 |
+
/* Generic.Inserted */
|
229 |
+
#chuanhu_chatbot .highlight .gs {
|
230 |
+
font-weight: bold
|
231 |
+
}
|
232 |
+
|
233 |
+
/* Generic.Strong */
|
234 |
+
#chuanhu_chatbot .highlight .gu {
|
235 |
+
color: #75715e
|
236 |
+
}
|
237 |
+
|
238 |
+
/* Generic.Subheading */
|
239 |
+
#chuanhu_chatbot .highlight .kc {
|
240 |
+
color: #66d9ef
|
241 |
+
}
|
242 |
+
|
243 |
+
/* Keyword.Constant */
|
244 |
+
#chuanhu_chatbot .highlight .kd {
|
245 |
+
color: #66d9ef
|
246 |
+
}
|
247 |
+
|
248 |
+
/* Keyword.Declaration */
|
249 |
+
#chuanhu_chatbot .highlight .kn {
|
250 |
+
color: #f92672
|
251 |
+
}
|
252 |
+
|
253 |
+
/* Keyword.Namespace */
|
254 |
+
#chuanhu_chatbot .highlight .kp {
|
255 |
+
color: #66d9ef
|
256 |
+
}
|
257 |
+
|
258 |
+
/* Keyword.Pseudo */
|
259 |
+
#chuanhu_chatbot .highlight .kr {
|
260 |
+
color: #66d9ef
|
261 |
+
}
|
262 |
+
|
263 |
+
/* Keyword.Reserved */
|
264 |
+
#chuanhu_chatbot .highlight .kt {
|
265 |
+
color: #66d9ef
|
266 |
+
}
|
267 |
+
|
268 |
+
/* Keyword.Type */
|
269 |
+
#chuanhu_chatbot .highlight .ld {
|
270 |
+
color: #162b74
|
271 |
+
}
|
272 |
+
|
273 |
+
/* Literal.Date */
|
274 |
+
#chuanhu_chatbot .highlight .m {
|
275 |
+
color: #ae81ff
|
276 |
+
}
|
277 |
+
|
278 |
+
/* Literal.Number */
|
279 |
+
#chuanhu_chatbot .highlight .s {
|
280 |
+
color: #062b84
|
281 |
+
}
|
282 |
+
|
283 |
+
/* Literal.String */
|
284 |
+
#chuanhu_chatbot .highlight .na {
|
285 |
+
color: #a6e22e
|
286 |
+
}
|
287 |
+
|
288 |
+
/* Name.Attribute */
|
289 |
+
#chuanhu_chatbot .highlight .nb {
|
290 |
+
color: #482822
|
291 |
+
}
|
292 |
+
|
293 |
+
/* Name.Builtin */
|
294 |
+
#chuanhu_chatbot .highlight .nc {
|
295 |
+
color: #a6e22e
|
296 |
+
}
|
297 |
+
|
298 |
+
/* Name.Class */
|
299 |
+
#chuanhu_chatbot .highlight .no {
|
300 |
+
color: #66d9ef
|
301 |
+
}
|
302 |
+
|
303 |
+
/* Name.Constant */
|
304 |
+
#chuanhu_chatbot .highlight .nd {
|
305 |
+
color: #a6e22e
|
306 |
+
}
|
307 |
+
|
308 |
+
/* Name.Decorator */
|
309 |
+
#chuanhu_chatbot .highlight .ni {
|
310 |
+
color: #482822
|
311 |
+
}
|
312 |
+
|
313 |
+
/* Name.Entity */
|
314 |
+
#chuanhu_chatbot .highlight .ne {
|
315 |
+
color: #a6e22e
|
316 |
+
}
|
317 |
+
|
318 |
+
/* Name.Exception */
|
319 |
+
#chuanhu_chatbot .highlight .nf {
|
320 |
+
color: #a6e22e
|
321 |
+
}
|
322 |
+
|
323 |
+
/* Name.Function */
|
324 |
+
#chuanhu_chatbot .highlight .nl {
|
325 |
+
color: #1818f2
|
326 |
+
}
|
327 |
+
|
328 |
+
/* Name.Label */
|
329 |
+
#chuanhu_chatbot .highlight .nn {
|
330 |
+
color: #482822
|
331 |
+
}
|
332 |
+
|
333 |
+
/* Name.Namespace */
|
334 |
+
#chuanhu_chatbot .highlight .nx {
|
335 |
+
color: #a6e22e
|
336 |
+
}
|
337 |
+
|
338 |
+
/* Name.Other */
|
339 |
+
#chuanhu_chatbot .highlight .py {
|
340 |
+
color: #482822
|
341 |
+
}
|
342 |
+
|
343 |
+
/* Name.Property */
|
344 |
+
#chuanhu_chatbot .highlight .nt {
|
345 |
+
color: #f92672
|
346 |
+
}
|
347 |
+
|
348 |
+
/* Name.Tag */
|
349 |
+
#chuanhu_chatbot .highlight .nv {
|
350 |
+
color: #482822
|
351 |
+
}
|
352 |
+
|
353 |
+
/* Name.Variable */
|
354 |
+
#chuanhu_chatbot .highlight .ow {
|
355 |
+
color: #f92672
|
356 |
+
}
|
357 |
+
|
358 |
+
/* Operator.Word */
|
359 |
+
#chuanhu_chatbot .highlight .w {
|
360 |
+
color: #482822
|
361 |
+
}
|
362 |
+
|
363 |
+
/* Text.Whitespace */
|
364 |
+
#chuanhu_chatbot .highlight .mb {
|
365 |
+
color: #ae81ff
|
366 |
+
}
|
367 |
+
|
368 |
+
/* Literal.Number.Bin */
|
369 |
+
#chuanhu_chatbot .highlight .mf {
|
370 |
+
color: #ae81ff
|
371 |
+
}
|
372 |
+
|
373 |
+
/* Literal.Number.Float */
|
374 |
+
#chuanhu_chatbot .highlight .mh {
|
375 |
+
color: #ae81ff
|
376 |
+
}
|
377 |
+
|
378 |
+
/* Literal.Number.Hex */
|
379 |
+
#chuanhu_chatbot .highlight .mi {
|
380 |
+
color: #ae81ff
|
381 |
+
}
|
382 |
+
|
383 |
+
/* Literal.Number.Integer */
|
384 |
+
#chuanhu_chatbot .highlight .mo {
|
385 |
+
color: #ae81ff
|
386 |
+
}
|
387 |
+
|
388 |
+
/* Literal.Number.Oct */
|
389 |
+
#chuanhu_chatbot .highlight .sa {
|
390 |
+
color: #162b74
|
391 |
+
}
|
392 |
+
|
393 |
+
/* Literal.String.Affix */
|
394 |
+
#chuanhu_chatbot .highlight .sb {
|
395 |
+
color: #161b74
|
396 |
+
}
|
397 |
+
|
398 |
+
/* Literal.String.Backtick */
|
399 |
+
#chuanhu_chatbot .highlight .sc {
|
400 |
+
color: #162b74
|
401 |
+
}
|
402 |
+
|
403 |
+
/* Literal.String.Char */
|
404 |
+
#chuanhu_chatbot .highlight .dl {
|
405 |
+
color: #162b74
|
406 |
+
}
|
407 |
+
|
408 |
+
/* Literal.String.Delimiter */
|
409 |
+
#chuanhu_chatbot .highlight .sd {
|
410 |
+
color: #162b74
|
411 |
+
}
|
412 |
+
|
413 |
+
/* Literal.String.Doc */
|
414 |
+
#chuanhu_chatbot .highlight .s2 {
|
415 |
+
color: #162b74
|
416 |
+
}
|
417 |
+
|
418 |
+
/* Literal.String.Double */
|
419 |
+
#chuanhu_chatbot .highlight .se {
|
420 |
+
color: #ae81ff
|
421 |
+
}
|
422 |
+
|
423 |
+
/* Literal.String.Escape */
|
424 |
+
#chuanhu_chatbot .highlight .sh {
|
425 |
+
color: #162b74
|
426 |
+
}
|
427 |
+
|
428 |
+
/* Literal.String.Heredoc */
|
429 |
+
#chuanhu_chatbot .highlight .si {
|
430 |
+
color: #162b74
|
431 |
+
}
|
432 |
+
|
433 |
+
/* Literal.String.Interpol */
|
434 |
+
#chuanhu_chatbot .highlight .sx {
|
435 |
+
color: #162b74
|
436 |
+
}
|
437 |
+
|
438 |
+
/* Literal.String.Other */
|
439 |
+
#chuanhu_chatbot .highlight .sr {
|
440 |
+
color: #162b74
|
441 |
+
}
|
442 |
+
|
443 |
+
/* Literal.String.Regex */
|
444 |
+
#chuanhu_chatbot .highlight .s1 {
|
445 |
+
color: #162b74
|
446 |
+
}
|
447 |
+
|
448 |
+
/* Literal.String.Single */
|
449 |
+
#chuanhu_chatbot .highlight .ss {
|
450 |
+
color: #162b74
|
451 |
+
}
|
452 |
+
|
453 |
+
/* Literal.String.Symbol */
|
454 |
+
#chuanhu_chatbot .highlight .bp {
|
455 |
+
color: #482822
|
456 |
+
}
|
457 |
+
|
458 |
+
/* Name.Builtin.Pseudo */
|
459 |
+
#chuanhu_chatbot .highlight .fm {
|
460 |
+
color: #a6e22e
|
461 |
+
}
|
462 |
+
|
463 |
+
/* Name.Function.Magic */
|
464 |
+
#chuanhu_chatbot .highlight .vc {
|
465 |
+
color: #482822
|
466 |
+
}
|
467 |
+
|
468 |
+
/* Name.Variable.Class */
|
469 |
+
#chuanhu_chatbot .highlight .vg {
|
470 |
+
color: #482822
|
471 |
+
}
|
472 |
+
|
473 |
+
/* Name.Variable.Global */
|
474 |
+
#chuanhu_chatbot .highlight .vi {
|
475 |
+
color: #482822
|
476 |
+
}
|
477 |
+
|
478 |
+
/* Name.Variable.Instance */
|
479 |
+
#chuanhu_chatbot .highlight .vm {
|
480 |
+
color: #482822
|
481 |
+
}
|
482 |
+
|
483 |
+
/* Name.Variable.Magic */
|
484 |
+
#chuanhu_chatbot .highlight .il {
|
485 |
+
color: #ae81ff
|
486 |
+
}
|
487 |
+
|
488 |
+
/* Literal.Number.Integer.Long */
|
ChatApp/assets/custom.js
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
// custom javascript here
|
ChatApp/interface/__pycache__/base_interface.cpython-39.pyc
ADDED
Binary file (574 Bytes). View file
|
|
ChatApp/interface/__pycache__/empty_stub_interface.cpython-39.pyc
ADDED
Binary file (1.33 kB). View file
|
|
ChatApp/interface/__pycache__/hddr_llama_onnx_interface.cpython-39.pyc
ADDED
Binary file (8.94 kB). View file
|
|
ChatApp/interface/base_interface.py
ADDED
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
class BaseLLMInterface:
|
2 |
+
def __init__(self):
|
3 |
+
pass
|
4 |
+
|
5 |
+
def foo(self):
|
6 |
+
pass
|
ChatApp/interface/empty_stub_interface.py
ADDED
@@ -0,0 +1,39 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from app_modules.utils import logging
|
2 |
+
|
3 |
+
|
4 |
+
class EmptyStubInterface:
|
5 |
+
def __init__(self):
|
6 |
+
pass
|
7 |
+
|
8 |
+
def initialize(self):
|
9 |
+
pass
|
10 |
+
|
11 |
+
def shutdown(self):
|
12 |
+
pass
|
13 |
+
|
14 |
+
def predict(
|
15 |
+
self,
|
16 |
+
text,
|
17 |
+
chatbot,
|
18 |
+
history,
|
19 |
+
top_p,
|
20 |
+
temperature,
|
21 |
+
max_length_tokens,
|
22 |
+
max_context_length_tokens,
|
23 |
+
):
|
24 |
+
logging.info("hi there")
|
25 |
+
logging.info("-" * 100)
|
26 |
+
# yield chatbot,history,"Empty context."
|
27 |
+
yield [[text, "No Model Found"]], [], "No Model Found"
|
28 |
+
|
29 |
+
def retry(
|
30 |
+
self,
|
31 |
+
text,
|
32 |
+
chatbot,
|
33 |
+
history,
|
34 |
+
top_p,
|
35 |
+
temperature,
|
36 |
+
max_length_tokens,
|
37 |
+
max_context_length_tokens,
|
38 |
+
):
|
39 |
+
yield chatbot, history, "Empty context"
|
ChatApp/interface/hddr_llama_onnx_interface.py
ADDED
@@ -0,0 +1,395 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import onnxruntime
|
3 |
+
import numpy as np
|
4 |
+
from sentencepiece import SentencePieceProcessor
|
5 |
+
from typing import List
|
6 |
+
import os
|
7 |
+
import logging
|
8 |
+
import gc
|
9 |
+
|
10 |
+
from .base_interface import BaseLLMInterface
|
11 |
+
|
12 |
+
from ChatApp.app_modules.utils import (
|
13 |
+
is_stop_word_or_prefix,
|
14 |
+
convert_to_markdown,
|
15 |
+
shared_state,
|
16 |
+
)
|
17 |
+
|
18 |
+
|
19 |
+
class Tokenizer:
|
20 |
+
def __init__(self, model_path: str):
|
21 |
+
# reload tokenizer
|
22 |
+
assert os.path.isfile(model_path), model_path
|
23 |
+
self.sp_model = SentencePieceProcessor(model_file=model_path)
|
24 |
+
|
25 |
+
# BOS / EOS token IDs
|
26 |
+
self.n_words: int = self.sp_model.vocab_size()
|
27 |
+
self.bos_id: int = self.sp_model.bos_id()
|
28 |
+
self.eos_id: int = self.sp_model.eos_id()
|
29 |
+
self.pad_id: int = self.sp_model.pad_id()
|
30 |
+
|
31 |
+
assert self.sp_model.vocab_size() == self.sp_model.get_piece_size()
|
32 |
+
|
33 |
+
def encode(self, s: str, bos: bool, eos: bool) -> List[int]:
|
34 |
+
assert type(s) is str
|
35 |
+
t = self.sp_model.encode(s)
|
36 |
+
if bos:
|
37 |
+
t = [self.bos_id] + t
|
38 |
+
if eos:
|
39 |
+
t = t + [self.eos_id]
|
40 |
+
return t
|
41 |
+
|
42 |
+
def decode(self, t: List[int]) -> str:
|
43 |
+
return self.sp_model.decode(t)
|
44 |
+
|
45 |
+
|
46 |
+
class LlamaOnnxInterface(BaseLLMInterface):
|
47 |
+
def __init__(self, onnx_file="", embedding_file="", tokenizer_path=""):
|
48 |
+
super().__init__()
|
49 |
+
|
50 |
+
self.onnx_file = onnx_file
|
51 |
+
self.embedding_file = embedding_file
|
52 |
+
self.tokenizer_path = tokenizer_path
|
53 |
+
|
54 |
+
self.total_count = 0
|
55 |
+
|
56 |
+
def initialize(self):
|
57 |
+
# Create the ONNX session
|
58 |
+
|
59 |
+
logging.info(f"Creating ONNX session for [{self.onnx_file}]")
|
60 |
+
options = onnxruntime.SessionOptions()
|
61 |
+
self.llm_session = onnxruntime.InferenceSession(
|
62 |
+
self.onnx_file,
|
63 |
+
sess_options=options,
|
64 |
+
providers=[
|
65 |
+
"DmlExecutionProvider",
|
66 |
+
"CUDAExecutionProvider",
|
67 |
+
"CPUExecutionProvider",
|
68 |
+
],
|
69 |
+
)
|
70 |
+
|
71 |
+
# get the data type used by the model
|
72 |
+
data_type_str = self.llm_session.get_inputs()[0].type
|
73 |
+
if data_type_str == "tensor(float16)":
|
74 |
+
self.data_type = np.float16
|
75 |
+
elif data_type_str == "tensor(float32)":
|
76 |
+
self.data_type = np.float32
|
77 |
+
else:
|
78 |
+
raise Exception(f"Unknown data type {data_type_str}")
|
79 |
+
|
80 |
+
logging.info(f"Detected Data Type [{self.data_type}]")
|
81 |
+
|
82 |
+
# Get the relevant shapes so we can create the inputs
|
83 |
+
for inputs_meta in self.llm_session._inputs_meta:
|
84 |
+
if inputs_meta.name == "x":
|
85 |
+
x_shape = inputs_meta.shape
|
86 |
+
elif inputs_meta.name == "attn_mask":
|
87 |
+
attn_mask_shape = inputs_meta.shape
|
88 |
+
elif inputs_meta.name == "k_cache":
|
89 |
+
k_cache_shape = inputs_meta.shape
|
90 |
+
|
91 |
+
self.hidden_size = x_shape[2]
|
92 |
+
self.max_seq_len = attn_mask_shape[1]
|
93 |
+
self.n_layers = k_cache_shape[1]
|
94 |
+
self.n_heads = k_cache_shape[3]
|
95 |
+
|
96 |
+
# Initialize the tokenizer and produce the initial tokens.
|
97 |
+
self.tokenizer = Tokenizer(model_path=self.tokenizer_path)
|
98 |
+
|
99 |
+
# create the embedding layer.
|
100 |
+
logging.info(
|
101 |
+
f"Creating the Embedding Layer. Size [{self.tokenizer.n_words}, {self.hidden_size}]"
|
102 |
+
)
|
103 |
+
self.embeddingLayer = torch.nn.Embedding(
|
104 |
+
self.tokenizer.n_words, self.hidden_size
|
105 |
+
)
|
106 |
+
|
107 |
+
# rg hack - dont have the embeddings.pth file - taking it from the original llama model
|
108 |
+
d = torch.load(self.embedding_file)
|
109 |
+
self.embeddingLayer.load_state_dict(d)
|
110 |
+
self.embeddingLayer.eval()
|
111 |
+
|
112 |
+
# Create the attention mask.
|
113 |
+
self.attn_mask = -10000.0 * torch.triu(
|
114 |
+
torch.ones(attn_mask_shape), diagonal=1
|
115 |
+
).cpu().detach().numpy().astype(self.data_type)
|
116 |
+
|
117 |
+
# Create the K and V caches.
|
118 |
+
self.head_dim = int(self.hidden_size / self.n_heads)
|
119 |
+
self.k_cache = np.zeros(
|
120 |
+
[1, self.n_layers, self.max_seq_len, self.n_heads, self.head_dim],
|
121 |
+
dtype=self.data_type,
|
122 |
+
)
|
123 |
+
self.v_cache = np.zeros(
|
124 |
+
[1, self.n_layers, self.max_seq_len, self.n_heads, self.head_dim],
|
125 |
+
dtype=self.data_type,
|
126 |
+
)
|
127 |
+
|
128 |
+
def shutdown(self):
|
129 |
+
pass
|
130 |
+
|
131 |
+
def generate_prompt_with_history(self, text, history, tokenizer, max_length=2048):
|
132 |
+
prompt = "[|Human|]Hey there I am a human that would like to have\
|
133 |
+
a conversation with you.\n[|AI|]Sure, I am happy to answer most questions\
|
134 |
+
\n[|Human|]Great, I insist that we take turns.\n[|AI|]I agree, we should\
|
135 |
+
take turns.\n[|Human|]Great, can we also keep answers short\n[|AI|]Yes, \
|
136 |
+
short answers are usually best"
|
137 |
+
|
138 |
+
history = ["\n[|Human|]{}\n[|AI|]{}".format(x[0], x[1]) for x in history]
|
139 |
+
history.append("\n[|Human|]{}\n[|AI|]".format(text))
|
140 |
+
history_text = ""
|
141 |
+
flag = False
|
142 |
+
for x in history[::-1]:
|
143 |
+
# tokens = self.tokenizer.encode(text, bos=True, eos=False)
|
144 |
+
if (
|
145 |
+
len(
|
146 |
+
self.tokenizer.encode(
|
147 |
+
prompt + history_text + x, bos=True, eos=False
|
148 |
+
)
|
149 |
+
)
|
150 |
+
<= max_length
|
151 |
+
):
|
152 |
+
history_text = x + history_text
|
153 |
+
flag = True
|
154 |
+
else:
|
155 |
+
break
|
156 |
+
if flag:
|
157 |
+
return prompt + history_text, torch.tensor(
|
158 |
+
self.tokenizer.encode(prompt + history_text, bos=True, eos=False)
|
159 |
+
).unsqueeze(0)
|
160 |
+
else:
|
161 |
+
return None
|
162 |
+
|
163 |
+
def sample_logits(
|
164 |
+
self,
|
165 |
+
logits: np.ndarray,
|
166 |
+
sampling_method: str = "greedy",
|
167 |
+
sampling_value: float = None,
|
168 |
+
temperature: float = 1.0,
|
169 |
+
) -> np.ndarray:
|
170 |
+
if temperature == 0 or sampling_method == "greedy":
|
171 |
+
next_token = np.argmax(logits, axis=-1).astype(np.int64)
|
172 |
+
|
173 |
+
elif sampling_method == "top_k" or sampling_method == "top_p":
|
174 |
+
assert sampling_value is not None
|
175 |
+
|
176 |
+
# temperature, converting to probabilities and sorting are common to both top-k and top-p
|
177 |
+
# convert logits to 32-bit float to avoid numerical issues with np.exp
|
178 |
+
logits = logits.astype(np.float32)
|
179 |
+
# Scale the logits by the temperature
|
180 |
+
logits /= temperature
|
181 |
+
# Convert logits to probabilities
|
182 |
+
probs = np.exp(logits) / np.sum(np.exp(logits))
|
183 |
+
# Sort th probabilities and indexes
|
184 |
+
sorted_probs = np.sort(probs)[:, ::-1]
|
185 |
+
sorted_indices = np.argsort(probs)[:, ::-1]
|
186 |
+
|
187 |
+
# find the index of interest for each of the methods.
|
188 |
+
if sampling_method == "top_k":
|
189 |
+
index_of_interest = int(sampling_value)
|
190 |
+
elif sampling_method == "top_p":
|
191 |
+
p = sampling_value
|
192 |
+
cumulative_probs = np.cumsum(sorted_probs, axis=-1)
|
193 |
+
# find the value of the first cumalitive probability that exceeds p
|
194 |
+
for index_of_interest, cumulative_prob in enumerate(
|
195 |
+
cumulative_probs[0]
|
196 |
+
):
|
197 |
+
if cumulative_prob > p:
|
198 |
+
break
|
199 |
+
|
200 |
+
probs_of_interest = sorted_probs[:, : index_of_interest + 1]
|
201 |
+
indices_of_interest = sorted_indices[:, : index_of_interest + 1]
|
202 |
+
# Normalize the probabilities and select the next token
|
203 |
+
probs_of_interest /= np.sum(probs_of_interest)
|
204 |
+
next_token = np.array(
|
205 |
+
[np.random.choice(indices_of_interest[0], p=probs_of_interest[0])]
|
206 |
+
)
|
207 |
+
else:
|
208 |
+
raise Exception(f"Unknown sampling method {sampling_method}")
|
209 |
+
|
210 |
+
return next_token
|
211 |
+
|
212 |
+
def greedy_search(
|
213 |
+
self,
|
214 |
+
input_ids,
|
215 |
+
model,
|
216 |
+
tokenizer,
|
217 |
+
stop_words: list,
|
218 |
+
max_length: int,
|
219 |
+
temperature: float = 1.0,
|
220 |
+
top_p: float = 1.0,
|
221 |
+
top_k: int = 25,
|
222 |
+
):
|
223 |
+
generated_tokens = []
|
224 |
+
pos = np.array(0)
|
225 |
+
|
226 |
+
x = (
|
227 |
+
self.embeddingLayer(torch.tensor(input_ids))
|
228 |
+
.detach()
|
229 |
+
.cpu()
|
230 |
+
.numpy()
|
231 |
+
.astype(self.data_type)
|
232 |
+
)
|
233 |
+
|
234 |
+
for i in range(max_length):
|
235 |
+
results = self.llm_session.run(
|
236 |
+
None,
|
237 |
+
{
|
238 |
+
"x": x,
|
239 |
+
"attn_mask": self.attn_mask,
|
240 |
+
"k_cache": self.k_cache[:, :, :pos],
|
241 |
+
"v_cache": self.v_cache[:, :, :pos],
|
242 |
+
"pos": pos.astype(np.int64),
|
243 |
+
},
|
244 |
+
)
|
245 |
+
logits, k_out, v_out = results[:3]
|
246 |
+
|
247 |
+
next_token = self.sample_logits(logits, "top_p", top_p, temperature)
|
248 |
+
next_token = next_token.reshape(1, -1)
|
249 |
+
|
250 |
+
# Stop if/when we get an ENDOFTEXT token before reaching maximum sequence length
|
251 |
+
if next_token[0] == tokenizer.eos_id:
|
252 |
+
del logits
|
253 |
+
gc.collect()
|
254 |
+
return
|
255 |
+
|
256 |
+
input_ids = torch.cat((input_ids, torch.tensor(next_token)), dim=-1)
|
257 |
+
|
258 |
+
generated_tokens.append(next_token[0].item())
|
259 |
+
text = tokenizer.decode(generated_tokens)
|
260 |
+
|
261 |
+
seq_len = x.shape[1]
|
262 |
+
self.k_cache[:, :, pos : pos + seq_len] = k_out
|
263 |
+
self.v_cache[:, :, pos : pos + seq_len] = v_out
|
264 |
+
pos = np.array(int(pos) + seq_len)
|
265 |
+
|
266 |
+
x = (
|
267 |
+
self.embeddingLayer(torch.tensor(next_token))
|
268 |
+
.unsqueeze(0)
|
269 |
+
.reshape([1, 1, self.hidden_size])
|
270 |
+
.cpu()
|
271 |
+
.detach()
|
272 |
+
.numpy()
|
273 |
+
.astype(self.data_type)
|
274 |
+
)
|
275 |
+
|
276 |
+
yield text
|
277 |
+
|
278 |
+
if any([x in text for x in stop_words]):
|
279 |
+
del logits
|
280 |
+
gc.collect()
|
281 |
+
return
|
282 |
+
|
283 |
+
def predict(
|
284 |
+
self,
|
285 |
+
text,
|
286 |
+
chatbot,
|
287 |
+
history,
|
288 |
+
top_p,
|
289 |
+
temperature,
|
290 |
+
max_length_tokens,
|
291 |
+
max_context_length_tokens,
|
292 |
+
):
|
293 |
+
if text == "":
|
294 |
+
yield chatbot, history, "Empty context."
|
295 |
+
return
|
296 |
+
try:
|
297 |
+
self.llm_session
|
298 |
+
except (ValueError, RuntimeError, TypeError):
|
299 |
+
yield [[text, "No Model Found"]], [], "No Model Found"
|
300 |
+
return
|
301 |
+
|
302 |
+
inputs = self.generate_prompt_with_history(
|
303 |
+
text, history, self.tokenizer, max_length=max_context_length_tokens
|
304 |
+
)
|
305 |
+
|
306 |
+
if inputs is None:
|
307 |
+
yield chatbot, history, "Input too long."
|
308 |
+
return
|
309 |
+
else:
|
310 |
+
prompt, inputs = inputs
|
311 |
+
|
312 |
+
input_ids = inputs[:, -max_context_length_tokens:]
|
313 |
+
|
314 |
+
# global total_count
|
315 |
+
self.total_count += 1
|
316 |
+
print(self.total_count)
|
317 |
+
|
318 |
+
self.head_dim = int(self.hidden_size / self.n_heads)
|
319 |
+
self.k_cache = np.zeros(
|
320 |
+
[1, self.n_layers, self.max_seq_len, self.n_heads, self.head_dim],
|
321 |
+
dtype=self.data_type,
|
322 |
+
)
|
323 |
+
self.v_cache = np.zeros(
|
324 |
+
[1, self.n_layers, self.max_seq_len, self.n_heads, self.head_dim],
|
325 |
+
dtype=self.data_type,
|
326 |
+
)
|
327 |
+
|
328 |
+
x = input_ids
|
329 |
+
|
330 |
+
for x in self.greedy_search(
|
331 |
+
input_ids,
|
332 |
+
self.llm_session,
|
333 |
+
self.tokenizer,
|
334 |
+
stop_words=["[|Human|]", "[|AI|]"],
|
335 |
+
max_length=max_length_tokens,
|
336 |
+
temperature=temperature,
|
337 |
+
top_p=top_p,
|
338 |
+
):
|
339 |
+
if is_stop_word_or_prefix(x, ["[|Human|]", "[|AI|]"]) is False:
|
340 |
+
if "[|Human|]" in x:
|
341 |
+
x = x[: x.index("[|Human|]")].strip()
|
342 |
+
if "[|AI|]" in x:
|
343 |
+
x = x[: x.index("[|AI|]")].strip()
|
344 |
+
x = x.strip()
|
345 |
+
a, b = [[y[0], convert_to_markdown(y[1])] for y in history] + [
|
346 |
+
[text, convert_to_markdown(x)]
|
347 |
+
], history + [[text, x]]
|
348 |
+
yield a, b, "Generating..."
|
349 |
+
if shared_state.interrupted:
|
350 |
+
shared_state.recover()
|
351 |
+
try:
|
352 |
+
yield a, b, "Stop: Success"
|
353 |
+
return
|
354 |
+
except Exception as e:
|
355 |
+
print(type(e).__name__, e)
|
356 |
+
pass
|
357 |
+
|
358 |
+
del input_ids
|
359 |
+
gc.collect()
|
360 |
+
torch.cuda.empty_cache()
|
361 |
+
|
362 |
+
try:
|
363 |
+
yield a, b, "Generate: Success"
|
364 |
+
except Exception as e:
|
365 |
+
print(type(e).__name__, e)
|
366 |
+
pass
|
367 |
+
|
368 |
+
return
|
369 |
+
|
370 |
+
def retry(
|
371 |
+
self,
|
372 |
+
text,
|
373 |
+
chatbot,
|
374 |
+
history,
|
375 |
+
top_p,
|
376 |
+
temperature,
|
377 |
+
max_length_tokens,
|
378 |
+
max_context_length_tokens,
|
379 |
+
):
|
380 |
+
logging.info("Retry...")
|
381 |
+
if len(history) == 0:
|
382 |
+
yield chatbot, history, "Empty context"
|
383 |
+
return
|
384 |
+
chatbot.pop()
|
385 |
+
inputs = history.pop()[0]
|
386 |
+
for x in self.predict(
|
387 |
+
inputs,
|
388 |
+
chatbot,
|
389 |
+
history,
|
390 |
+
top_p,
|
391 |
+
temperature,
|
392 |
+
max_length_tokens,
|
393 |
+
max_context_length_tokens,
|
394 |
+
):
|
395 |
+
yield x
|
ChatApp/requirements.txt
ADDED
@@ -0,0 +1,18 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
gradio
|
2 |
+
mdtex2html
|
3 |
+
pypinyin
|
4 |
+
tiktoken
|
5 |
+
socksio
|
6 |
+
tqdm
|
7 |
+
colorama
|
8 |
+
duckduckgo_search
|
9 |
+
Pygments
|
10 |
+
llama_index
|
11 |
+
langchain
|
12 |
+
markdown
|
13 |
+
markdown2
|
14 |
+
torch
|
15 |
+
git+https://github.com/huggingface/peft.git
|
16 |
+
git+https://github.com/huggingface/transformers.git
|
17 |
+
SentencePiece
|
18 |
+
onnxruntime-gpu
|