Update app.py
Browse files
app.py
CHANGED
@@ -19,7 +19,6 @@ from transformers import (
|
|
19 |
|
20 |
|
21 |
model_name = "WangZeJun/bloom-3b-moss-chat"
|
22 |
-
max_new_tokens = 1024
|
23 |
|
24 |
|
25 |
print(f"Starting to load the model {model_name} into memory")
|
@@ -43,14 +42,20 @@ class StopOnTokens(StoppingCriteria):
|
|
43 |
|
44 |
|
45 |
def convert_history_to_text(history):
|
46 |
-
|
47 |
user_input = history[-1][0]
|
48 |
-
|
49 |
input_pattern = "{}</s>"
|
50 |
text = input_pattern.format(user_input)
|
51 |
return text
|
52 |
|
53 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
54 |
|
55 |
def log_conversation(conversation_id, history, messages, generate_kwargs):
|
56 |
logging_url = os.getenv("LOGGING_URL", None)
|
@@ -78,7 +83,7 @@ def user(message, history):
|
|
78 |
return "", history + [[message, ""]]
|
79 |
|
80 |
|
81 |
-
def bot(history, temperature, top_p, top_k, repetition_penalty, conversation_id):
|
82 |
print(f"history: {history}")
|
83 |
# Initialize a StopOnTokens object
|
84 |
stop = StopOnTokens()
|
@@ -136,6 +141,64 @@ def bot(history, temperature, top_p, top_k, repetition_penalty, conversation_id)
|
|
136 |
history[-1][1] = partial_text
|
137 |
yield history
|
138 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
139 |
|
140 |
def get_uuid():
|
141 |
return str(uuid4())
|
@@ -162,7 +225,8 @@ with gr.Blocks(
|
|
162 |
).style(container=False)
|
163 |
with gr.Column():
|
164 |
with gr.Row():
|
165 |
-
|
|
|
166 |
stop = gr.Button("Stop")
|
167 |
clear = gr.Button("Clear")
|
168 |
with gr.Row():
|
@@ -172,18 +236,30 @@ with gr.Blocks(
|
|
172 |
with gr.Row():
|
173 |
temperature = gr.Slider(
|
174 |
label="Temperature",
|
175 |
-
value=0.
|
176 |
minimum=0.0,
|
177 |
maximum=1.0,
|
178 |
-
step=0.
|
179 |
interactive=True,
|
180 |
info="Higher values produce more diverse outputs",
|
181 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
182 |
with gr.Column():
|
183 |
with gr.Row():
|
184 |
top_p = gr.Slider(
|
185 |
label="Top-p (nucleus sampling)",
|
186 |
-
value=
|
187 |
minimum=0.0,
|
188 |
maximum=1,
|
189 |
step=0.01,
|
@@ -204,17 +280,16 @@ with gr.Blocks(
|
|
204 |
interactive=True,
|
205 |
info="Sample from a shortlist of top-k tokens — 0 to disable and sample from all tokens.",
|
206 |
)
|
207 |
-
|
208 |
-
|
209 |
-
|
210 |
-
|
211 |
-
|
212 |
-
|
213 |
-
|
214 |
-
|
215 |
-
|
216 |
-
|
217 |
-
)
|
218 |
# with gr.Row():
|
219 |
# gr.Markdown(
|
220 |
# "demo 2",
|
@@ -234,12 +309,13 @@ with gr.Blocks(
|
|
234 |
top_p,
|
235 |
top_k,
|
236 |
repetition_penalty,
|
|
|
237 |
conversation_id,
|
238 |
],
|
239 |
outputs=chatbot,
|
240 |
queue=True,
|
241 |
)
|
242 |
-
submit_click_event =
|
243 |
fn=user,
|
244 |
inputs=[msg, chatbot],
|
245 |
outputs=[msg, chatbot],
|
@@ -252,6 +328,26 @@ with gr.Blocks(
|
|
252 |
top_p,
|
253 |
top_k,
|
254 |
repetition_penalty,
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
255 |
conversation_id,
|
256 |
],
|
257 |
outputs=chatbot,
|
|
|
19 |
|
20 |
|
21 |
model_name = "WangZeJun/bloom-3b-moss-chat"
|
|
|
22 |
|
23 |
|
24 |
print(f"Starting to load the model {model_name} into memory")
|
|
|
42 |
|
43 |
|
44 |
def convert_history_to_text(history):
|
|
|
45 |
user_input = history[-1][0]
|
|
|
46 |
input_pattern = "{}</s>"
|
47 |
text = input_pattern.format(user_input)
|
48 |
return text
|
49 |
|
50 |
+
def convert_all_history_to_text(history):
|
51 |
+
text = ""
|
52 |
+
for instance in history:
|
53 |
+
text += instance[0]
|
54 |
+
text += "</s>"
|
55 |
+
if instance[1]:
|
56 |
+
text += instance[1]
|
57 |
+
text += "</s>"
|
58 |
+
return text
|
59 |
|
60 |
def log_conversation(conversation_id, history, messages, generate_kwargs):
|
61 |
logging_url = os.getenv("LOGGING_URL", None)
|
|
|
83 |
return "", history + [[message, ""]]
|
84 |
|
85 |
|
86 |
+
def bot(history, temperature, top_p, top_k, repetition_penalty, max_new_tokens, conversation_id):
|
87 |
print(f"history: {history}")
|
88 |
# Initialize a StopOnTokens object
|
89 |
stop = StopOnTokens()
|
|
|
141 |
history[-1][1] = partial_text
|
142 |
yield history
|
143 |
|
144 |
+
def multi_bot(history, temperature, top_p, top_k, repetition_penalty, max_new_tokens, conversation_id):
|
145 |
+
print(f"history: {history}")
|
146 |
+
# Initialize a StopOnTokens object
|
147 |
+
stop = StopOnTokens()
|
148 |
+
|
149 |
+
# Construct the input message string for the model by concatenating the current system message and conversation history
|
150 |
+
messages = convert_all_history_to_text(history)
|
151 |
+
|
152 |
+
# Tokenize the messages string
|
153 |
+
input_ids = tok(messages, return_tensors="pt").input_ids
|
154 |
+
input_ids = input_ids.to(m.device)
|
155 |
+
streamer = TextIteratorStreamer(
|
156 |
+
tok, timeout=10.0, skip_prompt=True, skip_special_tokens=True)
|
157 |
+
generate_kwargs = dict(
|
158 |
+
input_ids=input_ids,
|
159 |
+
max_new_tokens=max_new_tokens,
|
160 |
+
temperature=temperature,
|
161 |
+
do_sample=temperature > 0.0,
|
162 |
+
top_p=top_p,
|
163 |
+
top_k=top_k,
|
164 |
+
repetition_penalty=repetition_penalty,
|
165 |
+
streamer=streamer,
|
166 |
+
stopping_criteria=StoppingCriteriaList([stop]),
|
167 |
+
)
|
168 |
+
|
169 |
+
stream_complete = Event()
|
170 |
+
|
171 |
+
def generate_and_signal_complete():
|
172 |
+
m.generate(**generate_kwargs)
|
173 |
+
stream_complete.set()
|
174 |
+
|
175 |
+
def log_after_stream_complete():
|
176 |
+
stream_complete.wait()
|
177 |
+
log_conversation(
|
178 |
+
conversation_id,
|
179 |
+
history,
|
180 |
+
messages,
|
181 |
+
{
|
182 |
+
"top_k": top_k,
|
183 |
+
"top_p": top_p,
|
184 |
+
"temperature": temperature,
|
185 |
+
"repetition_penalty": repetition_penalty,
|
186 |
+
},
|
187 |
+
)
|
188 |
+
|
189 |
+
t1 = Thread(target=generate_and_signal_complete)
|
190 |
+
t1.start()
|
191 |
+
|
192 |
+
t2 = Thread(target=log_after_stream_complete)
|
193 |
+
t2.start()
|
194 |
+
|
195 |
+
# Initialize an empty string to store the generated text
|
196 |
+
partial_text = ""
|
197 |
+
for new_text in streamer:
|
198 |
+
partial_text += new_text
|
199 |
+
history[-1][1] = partial_text
|
200 |
+
yield history
|
201 |
+
|
202 |
|
203 |
def get_uuid():
|
204 |
return str(uuid4())
|
|
|
225 |
).style(container=False)
|
226 |
with gr.Column():
|
227 |
with gr.Row():
|
228 |
+
single_submit = gr.Button("单轮对话")
|
229 |
+
multi_submit = gr.Button("多轮对话")
|
230 |
stop = gr.Button("Stop")
|
231 |
clear = gr.Button("Clear")
|
232 |
with gr.Row():
|
|
|
236 |
with gr.Row():
|
237 |
temperature = gr.Slider(
|
238 |
label="Temperature",
|
239 |
+
value=0.3,
|
240 |
minimum=0.0,
|
241 |
maximum=1.0,
|
242 |
+
step=0.05,
|
243 |
interactive=True,
|
244 |
info="Higher values produce more diverse outputs",
|
245 |
)
|
246 |
+
with gr.Column():
|
247 |
+
with gr.Row():
|
248 |
+
repetition_penalty = gr.Slider(
|
249 |
+
label="Repetition Penalty",
|
250 |
+
value=1.2,
|
251 |
+
minimum=1.0,
|
252 |
+
maximum=2.0,
|
253 |
+
step=0.05,
|
254 |
+
interactive=True,
|
255 |
+
info="Penalize repetition — 1.0 to disable.",
|
256 |
+
)
|
257 |
+
with gr.Row():
|
258 |
with gr.Column():
|
259 |
with gr.Row():
|
260 |
top_p = gr.Slider(
|
261 |
label="Top-p (nucleus sampling)",
|
262 |
+
value=0.85,
|
263 |
minimum=0.0,
|
264 |
maximum=1,
|
265 |
step=0.01,
|
|
|
280 |
interactive=True,
|
281 |
info="Sample from a shortlist of top-k tokens — 0 to disable and sample from all tokens.",
|
282 |
)
|
283 |
+
with gr.Row():
|
284 |
+
max_new_tokens = gr.Slider(
|
285 |
+
label="Maximum new tokens",
|
286 |
+
value=1024,
|
287 |
+
minimum=0,
|
288 |
+
maximum=2048,
|
289 |
+
step=1,
|
290 |
+
interactive=True,
|
291 |
+
)
|
292 |
+
|
|
|
293 |
# with gr.Row():
|
294 |
# gr.Markdown(
|
295 |
# "demo 2",
|
|
|
309 |
top_p,
|
310 |
top_k,
|
311 |
repetition_penalty,
|
312 |
+
max_new_tokens,
|
313 |
conversation_id,
|
314 |
],
|
315 |
outputs=chatbot,
|
316 |
queue=True,
|
317 |
)
|
318 |
+
submit_click_event = single_submit.click(
|
319 |
fn=user,
|
320 |
inputs=[msg, chatbot],
|
321 |
outputs=[msg, chatbot],
|
|
|
328 |
top_p,
|
329 |
top_k,
|
330 |
repetition_penalty,
|
331 |
+
max_new_tokens,
|
332 |
+
conversation_id,
|
333 |
+
],
|
334 |
+
outputs=chatbot,
|
335 |
+
queue=True,
|
336 |
+
)
|
337 |
+
multi_click_event = multi_submit.click(
|
338 |
+
fn=user,
|
339 |
+
inputs=[msg, chatbot],
|
340 |
+
outputs=[msg, chatbot],
|
341 |
+
queue=False,
|
342 |
+
).then(
|
343 |
+
fn=multi_bot,
|
344 |
+
inputs=[
|
345 |
+
chatbot,
|
346 |
+
temperature,
|
347 |
+
top_p,
|
348 |
+
top_k,
|
349 |
+
repetition_penalty,
|
350 |
+
max_new_tokens,
|
351 |
conversation_id,
|
352 |
],
|
353 |
outputs=chatbot,
|