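# Repository-viewer residue kept as a comment so the file parses as Python:
#   author avatar "wenjiao's picture" | commit message "Update app.py" | commit 454f02d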
import argparse
from collections import defaultdict
import datetime
import json
import os
import time
import uuid
import gradio as gr
import requests
from fastchat.conversation import (
Conversation,
compute_skip_echo_len,
SeparatorStyle,
)
from fastchat.constants import LOGDIR
from fastchat.utils import (
build_logger,
server_error_msg,
violates_moderation,
moderation_msg,
)
from fastchat.serve.gradio_patch import Chatbot as grChatbot
from fastchat.serve.gradio_css import code_highlight_css
# Module-wide logger obtained from fastchat's build_logger; the second argument
# is presumably the log file name -- confirm against fastchat.utils.
logger = build_logger("gradio_web_server", "gradio_web_server.log")

# HTTP headers sent with the streaming generation request made in http_bot.
headers = {"User-Agent": "NeuralChat Client"}

# Pre-built button updates returned by the event handlers below: leave a button
# as it is, make it interactive, or make it non-interactive.
no_change_btn = gr.Button.update()
enable_btn = gr.Button.update(interactive=True)
disable_btn = gr.Button.update(interactive=False)

# Runtime configuration, overwritten by set_global_vars().
controller_url = None
enable_moderation = False
# Conversation template that add_text and http_bot copy to start a chat:
# a system preamble, the roles "Human" and "Assistant", no seed messages,
# offset 0, SeparatorStyle.SINGLE with sep "\n", and sep2 "</s>".
conv_template_bf16 = Conversation(
    system="A chat between a curious human and an artificial intelligence assistant. "
    "The assistant gives helpful, detailed, and polite answers to the human's questions.",
    roles=("Human", "Assistant"),
    messages=(),
    offset=0,
    sep_style=SeparatorStyle.SINGLE,
    sep="\n",
    sep2="</s>",
)
def set_global_vars(controller_url_, enable_moderation_):
    """Publish the controller address and the moderation switch module-wide.

    Args:
        controller_url_: value stored in the module global ``controller_url``.
        enable_moderation_: value stored in the module global
            ``enable_moderation``.
    """
    global controller_url, enable_moderation
    controller_url, enable_moderation = controller_url_, enable_moderation_
def get_conv_log_filename():
    """Return the path of today's conversation log file inside ``LOGDIR``.

    The file name is ``<YYYY>-<MM>-<DD>-conv.json``, built from the current
    local date.
    """
    today = datetime.date.today().isoformat()  # "YYYY-MM-DD"
    return os.path.join(LOGDIR, today + "-conv.json")
def get_model_list(controller_url):
    """Ask the controller to refresh its workers, then return its model list.

    Args:
        controller_url: base URL of the controller; the endpoint paths are
            appended to it.

    Returns:
        The value stored under ``"models"`` in the ``/list_models`` JSON reply.

    Raises:
        AssertionError: if ``/refresh_all_workers`` does not answer with 200.
    """
    refresh_reply = requests.post(controller_url + "/refresh_all_workers")
    assert refresh_reply.status_code == 200

    listing_reply = requests.post(controller_url + "/list_models")
    models = listing_reply.json()["models"]
    logger.info(f"Models: {models}")
    return models
# JavaScript snippet handed to demo.load(..., _js=...) in build_demo. In the
# browser it converts the page's query string into a plain object, logs that
# object to the console and returns it.
get_window_url_params = """
function() {
const params = new URLSearchParams(window.location.search);
url_params = Object.fromEntries(params);
console.log("url_params", url_params);
return url_params;
}
"""
def load_demo_single(models, url_params):
    """Build the component updates that reveal the chat UI on page load.

    Args:
        models: collection of known model names.
        url_params: mapping of the page's URL query parameters; when it holds
            a ``"model"`` entry that is one of ``models``, the dropdown is
            preselected with it.

    Returns:
        A 7-tuple: a ``None`` state, the dropdown update, and visibility
        updates for chatbot, textbox, button, row (all shown) and accordion
        (hidden).
    """
    selector_options = {"visible": True}
    if "model" in url_params and url_params["model"] in models:
        selector_options["value"] = url_params["model"]
    dropdown_update = gr.Dropdown.update(**selector_options)

    return (
        None,  # fresh session: no conversation state yet
        dropdown_update,
        gr.Chatbot.update(visible=True),
        gr.Textbox.update(visible=True),
        gr.Button.update(visible=True),
        gr.Row.update(visible=True),
        gr.Accordion.update(visible=False),
    )
def load_demo(url_params, request: gr.Request):
    """Page-load handler: log the visit and return the initial UI updates.

    Reads the module-level ``models`` list and delegates to
    ``load_demo_single``.
    """
    client_ip = request.client.host
    logger.info(f"load_demo. ip: {client_ip}. params: {url_params}")
    initial_updates = load_demo_single(models, url_params)
    return initial_updates
def vote_last_response(state, vote_type, model_selector, request: gr.Request):
    """Append one feedback record as a JSON line to today's conversation log.

    Args:
        state: conversation object; its ``dict()`` result is stored.
        vote_type: label written to the record's ``"type"`` field.
        model_selector: model name written to the record's ``"model"`` field.
        request: Gradio request; the client host is stored as ``"ip"``.
    """
    with open(get_conv_log_filename(), "a") as log_file:
        record = dict(
            tstamp=round(time.time(), 4),
            type=vote_type,
            model=model_selector,
            state=state.dict(),
            ip=request.client.host,
        )
        log_file.write(f"{json.dumps(record)}\n")
def upvote_last_response(state, model_selector, request: gr.Request):
    """Record an upvote; return an empty textbox value and three disabled buttons."""
    logger.info(f"upvote. ip: {request.client.host}")
    vote_last_response(state, "upvote", model_selector, request)
    return ("", disable_btn, disable_btn, disable_btn)
def downvote_last_response(state, model_selector, request: gr.Request):
    """Record a downvote; return an empty textbox value and three disabled buttons."""
    logger.info(f"downvote. ip: {request.client.host}")
    vote_last_response(state, "downvote", model_selector, request)
    return ("", disable_btn, disable_btn, disable_btn)
def flag_last_response(state, model_selector, request: gr.Request):
    """Record a flag; return an empty textbox value and three disabled buttons."""
    logger.info(f"flag. ip: {request.client.host}")
    vote_last_response(state, "flag", model_selector, request)
    return ("", disable_btn, disable_btn, disable_btn)
def regenerate(state, request: gr.Request):
    """Blank the last message's content so it can be generated again.

    Returns:
        ``(state, chatbot messages, "")`` followed by five disabled-button
        updates.
    """
    logger.info(f"regenerate. ip: {request.client.host}")
    last_message = state.messages[-1]
    last_message[-1] = None  # drop the previous content of the final message
    state.skip_next = False
    chat_view = state.to_gradio_chatbot()
    return (state, chat_view, "") + 5 * (disable_btn,)
def clear_history(request: gr.Request):
    """Reset the session: no state, empty chat, empty textbox, five disabled buttons."""
    logger.info(f"clear_history. ip: {request.client.host}")
    cleared = (None, [], "")
    return cleared + 5 * (disable_btn,)
def add_text(state, text, request: gr.Request):
    """Append the user's message (and an empty reply slot) to the conversation.

    Args:
        state: current conversation, or ``None`` to start from a copy of
            ``conv_template_bf16``.
        text: the submitted text.
        request: Gradio request, used for logging the client host.

    Returns:
        ``(state, chatbot messages, textbox value)`` followed by five button
        updates. Empty or moderation-flagged input sets ``state.skip_next`` to
        True and leaves the buttons unchanged; a flagged input also puts
        ``moderation_msg`` into the textbox.
    """
    logger.info(f"add_text. ip: {request.client.host}. len: {len(text)}")

    if state is None:
        state = conv_template_bf16.copy()

    untouched_buttons = (no_change_btn,) * 5

    # Nothing typed: tell http_bot to skip the generation step.
    if not text:
        state.skip_next = True
        return (state, state.to_gradio_chatbot(), "") + untouched_buttons

    if enable_moderation and violates_moderation(text):
        logger.info(f"violate moderation. ip: {request.client.host}. text: {text}")
        state.skip_next = True
        return (state, state.to_gradio_chatbot(), moderation_msg) + untouched_buttons

    # Hard cut-off: only the first 1536 characters are kept.
    state.append_message(state.roles[0], text[:1536])
    state.append_message(state.roles[1], None)
    state.skip_next = False
    return (state, state.to_gradio_chatbot(), "") + (disable_btn,) * 5
def post_process_code(code):
    """Undo escaped underscores (``\\_``) inside fenced code blocks.

    The text is split on the fence marker (a newline followed by three
    backticks). Only when the fences are balanced (an odd number of pieces)
    are the pieces at odd positions, i.e. those between an opening and a
    closing fence, rewritten. Otherwise the text is returned unchanged.
    """
    fence = "\n```"
    if fence not in code:
        return code

    pieces = code.split(fence)
    if len(pieces) % 2 == 0:
        # Unbalanced fences: re-joining the pieces would just rebuild `code`.
        return code

    pieces[1::2] = [inside.replace("\\_", "_") for inside in pieces[1::2]]
    return fence.join(pieces)
def http_bot(state, model_selector, temperature, max_new_tokens, request: gr.Request):
    """Stream the model's reply for the pending turn of ``state``.

    Generator used as a Gradio event handler. Every yielded value is
    ``(state, chatbot messages)`` followed by five button updates, in the
    order upvote, downvote, flag, regenerate, clear.

    Args:
        state: conversation object whose last message is the empty reply slot.
        model_selector: name of the model to query.
        temperature: sampling temperature; converted with ``float``.
        max_new_tokens: generation limit; converted with ``int``.
        request: Gradio request, used for the client host in logs.

    Side effects:
        Sends HTTP requests to the controller and to the worker it names, and
        after a successful generation appends a ``"chat"`` record to today's
        conversation log file.
    """
    logger.info(f"http_bot. ip: {request.client.host}")
    start_tstamp = time.time()
    model_name = model_selector
    temperature = float(temperature)
    max_new_tokens = int(max_new_tokens)

    # Shown whenever generation fails: the three feedback buttons stay
    # disabled, regenerate and clear become usable.
    failure_buttons = (disable_btn,) * 3 + (enable_btn,) * 2

    if state.skip_next:
        # This generate call is skipped due to invalid inputs
        yield (state, state.to_gradio_chatbot()) + (no_change_btn,) * 5
        return

    if len(state.messages) == state.offset + 2:
        # First round of conversation: rebuild the state from the template and
        # give it a conversation id and a model name.
        new_state = conv_template_bf16.copy()
        new_state.conv_id = uuid.uuid4().hex
        new_state.model_name = state.model_name or model_selector
        new_state.append_message(new_state.roles[0], state.messages[-2][1])
        new_state.append_message(new_state.roles[1], None)
        state = new_state

    # Query worker address
    ret = requests.post(
        controller_url + "/get_worker_address", json={"model": model_name}
    )
    worker_addr = ret.json()["address"]
    logger.info(f"model_name: {model_name}, worker_addr: {worker_addr}")

    # No available worker
    if worker_addr == "":
        state.messages[-1][-1] = server_error_msg
        yield (state, state.to_gradio_chatbot()) + failure_buttons
        return

    # Construct prompt
    prompt = state.get_prompt()
    # Number of leading characters dropped from every streamed "text" value
    # (presumably the echoed prompt -- confirm against compute_skip_echo_len).
    skip_echo_len = compute_skip_echo_len(model_name, state, prompt)

    # Make requests
    pload = {
        "model": model_name,
        "prompt": prompt,
        "temperature": temperature,
        "max_new_tokens": max_new_tokens,
        "stop": "</s>",
    }
    logger.info(f"==== request ====\n{pload}")

    start_time = time.time()
    # Defined up front so the final log line cannot hit an unbound name when
    # the stream delivers no non-empty chunk.
    output = ""

    # Show a cursor while waiting for the first chunk.
    state.messages[-1][-1] = "▌"
    yield (state, state.to_gradio_chatbot()) + (disable_btn,) * 5

    try:
        # Stream output
        response = requests.post(
            worker_addr + "/worker_generate_stream",
            headers=headers,
            json=pload,
            stream=True,
            timeout=20,
        )
        for chunk in response.iter_lines(decode_unicode=False, delimiter=b"\0"):
            if chunk:
                data = json.loads(chunk.decode())
                if data["error_code"] == 0:
                    output = data["text"][skip_echo_len:].strip()
                    output = post_process_code(output)
                    state.messages[-1][-1] = output + "▌"
                    yield (state, state.to_gradio_chatbot()) + (disable_btn,) * 5
                else:
                    output = data["text"] + f" (error_code: {data['error_code']})"
                    state.messages[-1][-1] = output
                    yield (state, state.to_gradio_chatbot()) + failure_buttons
                    return
                # Short pause between successive UI updates.
                time.sleep(0.02)
    except requests.exceptions.RequestException as e:
        # Report the failure in the server log as well as in the chat window.
        logger.error(f"http_bot request failed. ip: {request.client.host}. error: {e}")
        state.messages[-1][-1] = server_error_msg + " (error_code: 4)"
        yield (state, state.to_gradio_chatbot()) + failure_buttons
        return

    # Wall-clock finish time goes into the log; the duration since the request
    # was sent is what the user sees.
    finish_tstamp = time.time()
    elapsed_seconds = finish_tstamp - start_time
    elapsed_time = "\n✅generation elapsed time: {}s".format(round(elapsed_seconds, 4))
    # [:-1] removes the trailing cursor character before appending the timing.
    state.messages[-1][-1] = state.messages[-1][-1][:-1] + elapsed_time
    yield (state, state.to_gradio_chatbot()) + (enable_btn,) * 5

    logger.info(f"{output}")

    with open(get_conv_log_filename(), "a") as fout:
        data = {
            "tstamp": round(finish_tstamp, 4),
            "type": "chat",
            "model": model_name,
            "gen_params": {
                "temperature": temperature,
                "max_new_tokens": max_new_tokens,
            },
            "start": round(start_tstamp, 4),
            "finish": round(finish_tstamp, 4),
            "state": state.dict(),
            "ip": request.client.host,
        }
        fout.write(json.dumps(data) + "\n")
# Stylesheet passed to gr.Blocks(css=...) in build_demo: the code-highlighting
# rules imported from fastchat followed by this app's own rules. The selectors
# #notice_markdown, #chatbot, #text-box-style, #btn-style, #btn-send-style and
# #btn-list-style match elem_id values assigned in build_single_model_ui.
block_css = (
    code_highlight_css
    + """
pre {
white-space: pre-wrap; /* Since CSS 2.1 */
white-space: -moz-pre-wrap; /* Mozilla, since 1999 */
white-space: -pre-wrap; /* Opera 4-6 */
white-space: -o-pre-wrap; /* Opera 7 */
word-wrap: break-word; /* Internet Explorer 5.5+ */
}
#notice_markdown th {
display: none;
}
#notice_markdown {
text-align: center;
background: #874bec;
padding: 1%;
}
#notice_markdown h1, #notice_markdown h4 {
color: #fff;
margin-top: 0;
}
gradio-app {
background: linear-gradient(to bottom, #ba97d8, #5400ff) !important;
padding: 3%;
}
.gradio-container {
margin: 0 auto !important;
width: 70% !important;
padding: 0 !important;
}
#chatbot {
border-style: solid;
overflow: visible;
margin: 1% 4%;
width: 90%;
box-shadow: 0 15px 15px -5px rgba(0, 0, 0, 0.2);
border: 1px solid #ddd;
}
#text-box-style, #btn-style {
width: 90%;
margin: 1% 4%;
}
.user, .bot {
width: 80% !important;
}
.bot {
white-space: pre-wrap !important;
line-height: 1.3 !important;
display: flex;
flex-direction: column;
justify-content: flex-start;
}
#btn-send-style {
background: rgb(0, 180, 50);
color: #fff;
}
#btn-list-style {
background: #eee0;
border: 1px solid #691ef7;
}
"""
)
def build_single_model_ui(models):
    """Create the chat widgets and wire their event handlers.

    Must be called inside a ``gr.Blocks`` context. Most widgets start hidden;
    ``load_demo`` reveals them on page load.

    Args:
        models: list of model names offered in the dropdown; the first one is
            the default, or ``""`` when the list is empty.

    Returns:
        ``(state, model_selector, chatbot, textbox, send_btn, button_row,
        parameter_row)``.
    """
    notice_markdown = """
# 🤖 NeuralChat
#### deployed on 4th Gen Intel Xeon Scalable Processors codenamed Sapphire Rapids
"""
    learn_more_markdown = """
"""
    default_model = models[0] if models else ""

    # ---- Layout -----------------------------------------------------------
    state = gr.State()
    gr.Markdown(notice_markdown, elem_id="notice_markdown")

    with gr.Row(elem_id="model_selector_row", visible=False):
        model_selector = gr.Dropdown(
            choices=models,
            value=default_model,
            interactive=True,
            show_label=False,
        ).style(container=False)

    chatbot = grChatbot(elem_id="chatbot", visible=False).style(height=550)

    with gr.Row(elem_id="text-box-style"):
        with gr.Column(scale=20):
            textbox = gr.Textbox(
                show_label=False,
                placeholder="Enter text and press ENTER",
                visible=False,
            ).style(container=False)
        with gr.Column(scale=1, min_width=50):
            send_btn = gr.Button(value="Send", visible=False, elem_id="btn-send-style")

    with gr.Row(visible=False, elem_id="btn-style") as button_row:
        upvote_btn = gr.Button(
            value="👍 Upvote", interactive=False, elem_id="btn-list-style"
        )
        downvote_btn = gr.Button(
            value="👎 Downvote", interactive=False, elem_id="btn-list-style"
        )
        flag_btn = gr.Button(
            value="⚠️ Flag", interactive=False, visible=False, elem_id="btn-list-style"
        )
        # stop_btn = gr.Button(value="⏹️ Stop Generation", interactive=False)
        regenerate_btn = gr.Button(
            value="🔄 Regenerate", interactive=False, elem_id="btn-list-style"
        )
        clear_btn = gr.Button(
            value="🗑️ Clear history", interactive=False, elem_id="btn-list-style"
        )

    with gr.Accordion("Parameters", open=False, visible=False) as parameter_row:
        temperature = gr.Slider(
            minimum=0.0,
            maximum=1.0,
            value=0.95,
            step=0.1,
            interactive=True,
            label="Temperature",
        )
        max_output_tokens = gr.Slider(
            minimum=0,
            maximum=1024,
            value=512,
            step=64,
            interactive=True,
            label="Max output tokens",
        )

    gr.Markdown(learn_more_markdown)

    # ---- Register listeners -------------------------------------------------
    btn_list = [upvote_btn, downvote_btn, flag_btn, regenerate_btn, clear_btn]

    def _then_generate(event):
        # Chain the streaming generation step after a state-preparing event.
        return event.then(
            http_bot,
            [state, model_selector, temperature, max_output_tokens],
            [state, chatbot] + btn_list,
        )

    feedback_handlers = (
        (upvote_btn, upvote_last_response),
        (downvote_btn, downvote_last_response),
        (flag_btn, flag_last_response),
    )
    for feedback_button, handler in feedback_handlers:
        feedback_button.click(
            handler,
            [state, model_selector],
            [textbox, upvote_btn, downvote_btn, flag_btn],
        )

    _then_generate(
        regenerate_btn.click(regenerate, state, [state, chatbot, textbox] + btn_list)
    )

    # Clearing explicitly and switching models both reset the conversation.
    for reset_trigger in (clear_btn.click, model_selector.change):
        reset_trigger(clear_history, None, [state, chatbot, textbox] + btn_list)

    # Pressing ENTER in the textbox and clicking "Send" behave the same way.
    for send_trigger in (textbox.submit, send_btn.click):
        _then_generate(
            send_trigger(
                add_text, [state, textbox], [state, chatbot, textbox] + btn_list
            )
        )

    return state, model_selector, chatbot, textbox, send_btn, button_row, parameter_row
def build_demo(models):
    """Assemble the Gradio Blocks app and register its page-load handler.

    Reads the module-level ``model_list_mode``; only ``"once"`` is supported.

    Args:
        models: list of model names forwarded to ``build_single_model_ui``.

    Returns:
        The constructed ``gr.Blocks`` object.

    Raises:
        ValueError: if ``model_list_mode`` is anything other than ``"once"``.
    """
    blocks = gr.Blocks(
        title="NeuralChat · Intel",
        theme=gr.themes.Base(),
        css=block_css,
    )
    with blocks as demo:
        url_params = gr.JSON(visible=False)
        # Order: state, model_selector, chatbot, textbox, send_btn,
        # button_row, parameter_row -- the outputs load_demo fills in.
        ui_components = build_single_model_ui(models)

        if model_list_mode != "once":
            raise ValueError(f"Unknown model list mode: {model_list_mode}")

        demo.load(
            load_demo,
            [url_params],
            list(ui_components),
            _js=get_window_url_params,
        )
    return demo
if __name__ == "__main__":
    # Deployment settings. `models` and `model_list_mode` are read as module
    # globals by load_demo and build_demo, so they must be set before the UI
    # is built.
    host = "mlp-dgx-01.sh.intel.com"
    controller_url = "http://mlp-dgx-01.sh.intel.com:21001"
    model_list_mode = "once"
    concurrency_count = 10
    share = True
    moderate = False

    set_global_vars(controller_url, moderate)
    models = get_model_list(controller_url)
    demo = build_demo(models)

    queue_options = dict(
        concurrency_count=concurrency_count,
        status_update_rate=10,
        api_open=False,
    )
    launch_options = dict(server_name=host, share=share, max_threads=200)
    demo.queue(**queue_options).launch(**launch_options)