import argparse
from collections import defaultdict
import datetime
import json
import os
import time
import uuid

os.system("pip install --upgrade gradio")

import gradio as gr
import requests

from fastchat.conversation import (
    Conversation,
    compute_skip_echo_len,
    SeparatorStyle,
)
from fastchat.constants import LOGDIR
from fastchat.utils import (
    build_logger,
    server_error_msg,
    violates_moderation,
    moderation_msg,
)
from fastchat.serve.gradio_patch import Chatbot as grChatbot
from fastchat.serve.gradio_css import code_highlight_css

logger = build_logger("gradio_web_server", "gradio_web_server.log")

headers = {"User-Agent": "NeuralChat Client"}

no_change_btn = gr.Button.update()
enable_btn = gr.Button.update(interactive=True)
disable_btn = gr.Button.update(interactive=False)

controller_url = None
enable_moderation = False

conv_template_bf16 = Conversation(
    system="A chat between a curious human and an artificial intelligence assistant. "
    "The assistant gives helpful, detailed, and polite answers to the human's questions.",
    roles=("Human", "Assistant"),
    messages=(),
    offset=0,
    sep_style=SeparatorStyle.SINGLE,
    sep="\n",
    sep2="",
)


def set_global_vars(controller_url_, enable_moderation_):
    global controller_url, enable_moderation
    controller_url = controller_url_
    enable_moderation = enable_moderation_


def get_conv_log_filename():
    t = datetime.datetime.now()
    name = os.path.join(LOGDIR, f"{t.year}-{t.month:02d}-{t.day:02d}-conv.json")
    return name


def get_model_list(controller_url):
    ret = requests.post(controller_url + "/refresh_all_workers")
    assert ret.status_code == 200
    ret = requests.post(controller_url + "/list_models")
    models = ret.json()["models"]
    logger.info(f"Models: {models}")
    return models


get_window_url_params = """
function() {
    const params = new URLSearchParams(window.location.search);
    url_params = Object.fromEntries(params);
    console.log("url_params", url_params);
    return url_params;
}
"""


def load_demo_single(models, url_params):
    dropdown_update = gr.Dropdown.update(visible=True)
    if "model" in url_params:
        model = url_params["model"]
        if model in models:
            dropdown_update = gr.Dropdown.update(value=model, visible=True)

    state = None
    return (
        state,
        dropdown_update,
        gr.Chatbot.update(visible=True),
        gr.Textbox.update(visible=True),
        gr.Button.update(visible=True),
        gr.Row.update(visible=True),
        gr.Accordion.update(visible=True),
    )


def load_demo(url_params, request: gr.Request):
    logger.info(f"load_demo. ip: {request.client.host}. params: {url_params}")
    return load_demo_single(models, url_params)


def vote_last_response(state, vote_type, model_selector, request: gr.Request):
    with open(get_conv_log_filename(), "a") as fout:
        data = {
            "tstamp": round(time.time(), 4),
            "type": vote_type,
            "model": model_selector,
            "state": state.dict(),
            "ip": request.client.host,
        }
        fout.write(json.dumps(data) + "\n")


def upvote_last_response(state, model_selector, request: gr.Request):
    logger.info(f"upvote. ip: {request.client.host}")
    vote_last_response(state, "upvote", model_selector, request)
    return ("",) + (disable_btn,) * 3


def downvote_last_response(state, model_selector, request: gr.Request):
    logger.info(f"downvote. ip: {request.client.host}")
    vote_last_response(state, "downvote", model_selector, request)
    return ("",) + (disable_btn,) * 3


def flag_last_response(state, model_selector, request: gr.Request):
    logger.info(f"flag. ip: {request.client.host}")
    vote_last_response(state, "flag", model_selector, request)
    return ("",) + (disable_btn,) * 3
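
# Illustrative helper, not part of the original server: the vote handlers
# above append one JSON object per line to the conversation log, so the
# feedback can be tallied offline. A minimal sketch, assuming the JSONL
# format written by vote_last_response:
def summarize_votes(log_path):
    """Count upvote/downvote/flag records per model in a conv log."""
    counts = defaultdict(lambda: defaultdict(int))
    with open(log_path) as fin:
        for line in fin:
            record = json.loads(line)
            if record.get("type") in ("upvote", "downvote", "flag"):
                counts[record.get("model")][record["type"]] += 1
    return counts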

def regenerate(state, request: gr.Request):
    logger.info(f"regenerate. ip: {request.client.host}")
    state.messages[-1][-1] = None
    state.skip_next = False
    return (state, state.to_gradio_chatbot(), "") + (disable_btn,) * 5


def clear_history(request: gr.Request):
    logger.info(f"clear_history. ip: {request.client.host}")
    state = None
    return (state, [], "") + (disable_btn,) * 5


def add_text(state, text, request: gr.Request):
    logger.info(f"add_text. ip: {request.client.host}. len: {len(text)}")
    if state is None:
        state = conv_template_bf16.copy()

    if len(text) <= 0:
        state.skip_next = True
        return (state, state.to_gradio_chatbot(), "") + (no_change_btn,) * 5
    if enable_moderation:
        flagged = violates_moderation(text)
        if flagged:
            logger.info(f"violate moderation. ip: {request.client.host}. text: {text}")
            state.skip_next = True
            return (state, state.to_gradio_chatbot(), moderation_msg) + (
                no_change_btn,
            ) * 5

    text = text[:1536]  # Hard cut-off
    state.append_message(state.roles[0], text)
    state.append_message(state.roles[1], None)
    state.skip_next = False
    return (state, state.to_gradio_chatbot(), "") + (disable_btn,) * 5


def post_process_code(code):
    sep = "\n```"
    if sep in code:
        blocks = code.split(sep)
        if len(blocks) % 2 == 1:
            for i in range(1, len(blocks), 2):
                blocks[i] = blocks[i].replace("\\_", "_")
        code = sep.join(blocks)
    return code
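
# For example (illustrative call, not in the original file): escaped
# underscores inside a fenced code block are restored, while text outside
# the fences keeps its escaping:
#
#   post_process_code("see foo\\_bar\n```\nfoo\\_bar\n```")
#   # -> "see foo\\_bar\n```\nfoo_bar\n```"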

def http_bot(state, model_selector, temperature, max_new_tokens, request: gr.Request):
    logger.info(f"http_bot. ip: {request.client.host}")
    start_tstamp = time.time()
    model_name = model_selector
    temperature = float(temperature)
    max_new_tokens = int(max_new_tokens)

    if state.skip_next:
        # This generate call is skipped due to invalid inputs
        yield (state, state.to_gradio_chatbot()) + (no_change_btn,) * 5
        return

    if len(state.messages) == state.offset + 2:
        # First round of conversation
        new_state = conv_template_bf16.copy()
        new_state.conv_id = uuid.uuid4().hex
        new_state.model_name = state.model_name or model_selector
        new_state.append_message(new_state.roles[0], state.messages[-2][1])
        new_state.append_message(new_state.roles[1], None)
        state = new_state

    # Query worker address
    ret = requests.post(
        controller_url + "/get_worker_address", json={"model": model_name}
    )
    worker_addr = ret.json()["address"]
    logger.info(f"model_name: {model_name}, worker_addr: {worker_addr}")

    # No available worker
    if worker_addr == "":
        state.messages[-1][-1] = server_error_msg
        yield (
            state,
            state.to_gradio_chatbot(),
            disable_btn,
            disable_btn,
            disable_btn,
            enable_btn,
            enable_btn,
        )
        return

    # Construct prompt
    prompt = state.get_prompt()
    skip_echo_len = compute_skip_echo_len(model_name, state, prompt)

    # Make requests
    pload = {
        "model": model_name,
        "prompt": prompt,
        "temperature": temperature,
        "max_new_tokens": max_new_tokens,
        "stop": "",
    }
    logger.info(f"==== request ====\n{pload}")

    start_time = time.time()

    state.messages[-1][-1] = "▌"
    yield (state, state.to_gradio_chatbot()) + (disable_btn,) * 5

    try:
        # Stream output
        response = requests.post(
            controller_url + "/worker_generate_stream",
            headers=headers,
            json=pload,
            stream=True,
            timeout=20,
        )
        for chunk in response.iter_lines(decode_unicode=False, delimiter=b"\0"):
            if chunk:
                data = json.loads(chunk.decode())
                if data["error_code"] == 0:
                    output = data["text"][skip_echo_len:].strip()
                    output = post_process_code(output)
                    state.messages[-1][-1] = output + "▌"
                    yield (state, state.to_gradio_chatbot()) + (disable_btn,) * 5
                else:
                    output = data["text"] + f" (error_code: {data['error_code']})"
                    state.messages[-1][-1] = output
                    yield (state, state.to_gradio_chatbot()) + (
                        disable_btn,
                        disable_btn,
                        disable_btn,
                        enable_btn,
                        enable_btn,
                    )
                    return
                time.sleep(0.005)
    except requests.exceptions.RequestException:
        state.messages[-1][-1] = server_error_msg + " (error_code: 4)"
        yield (state, state.to_gradio_chatbot()) + (
            disable_btn,
            disable_btn,
            disable_btn,
            enable_btn,
            enable_btn,
        )
        return

    finish_tstamp = time.time() - start_time
    elapsed_time = "\n✅generation elapsed time: {}s".format(round(finish_tstamp, 4))
    # elapsed_time = "\n{}s".format(round(finish_tstamp, 4))

    state.messages[-1][-1] = state.messages[-1][-1][:-1] + elapsed_time
    yield (state, state.to_gradio_chatbot()) + (enable_btn,) * 5

    logger.info(f"{output}")

    with open(get_conv_log_filename(), "a") as fout:
        data = {
            "tstamp": round(finish_tstamp, 4),
            "type": "chat",
            "model": model_name,
            "gen_params": {
                "temperature": temperature,
                "max_new_tokens": max_new_tokens,
            },
            "start": round(start_tstamp, 4),
            # Absolute time at which generation finished.
            "finish": round(start_time + finish_tstamp, 4),
            "state": state.dict(),
            "ip": request.client.host,
        }
        fout.write(json.dumps(data) + "\n")


block_css = (
    code_highlight_css
    + """
pre {
    white-space: pre-wrap;       /* Since CSS 2.1 */
    white-space: -moz-pre-wrap;  /* Mozilla, since 1999 */
    white-space: -pre-wrap;      /* Opera 4-6 */
    white-space: -o-pre-wrap;    /* Opera 7 */
    word-wrap: break-word;       /* Internet Explorer 5.5+ */
}
#notice_markdown th {
    display: none;
}
#notice_markdown {
    text-align: center;
    background: #874bec;
    padding: 1%;
    height: 4.3rem;
    color: #fff !important;
    margin-top: 0;
}
#notice_markdown p {
    color: #fff !important;
}
#notice_markdown h1,
#notice_markdown h4 {
    color: #fff;
    margin-top: 0;
}
gradio-app {
    background: linear-gradient(to bottom, #ba97d8, #5400ff) !important;
    padding: 3%;
}
.gradio-container {
    margin: 0 auto !important;
    width: 70% !important;
    padding: 0 !important;
    background: #fff !important;
    border-radius: 5px !important;
}
#chatbot {
    border-style: solid;
    overflow: visible;
    margin: 1% 4%;
    width: 90%;
    box-shadow: 0 15px 15px -5px rgba(0, 0, 0, 0.2);
    border: 1px solid #ddd;
}
#text-box-style,
#btn-style {
    width: 90%;
    margin: 1% 4%;
}
.user,
.bot {
    width: 80% !important;
}
.bot {
    white-space: pre-wrap !important;
    line-height: 1.3 !important;
    display: flex;
    flex-direction: column;
    justify-content: flex-start;
}
#btn-send-style {
    background: rgb(0, 180, 50);
    color: #fff;
}
#btn-list-style {
    background: #eee0;
    border: 1px solid #691ef7;
}
.title {
    font-size: 1.5rem;
    font-weight: 700;
    color: #fff !important;
}
footer {
    display: none !important;
}
.footer {
    margin-bottom: 45px;
    margin-top: 35px;
    text-align: center;
    border-bottom: 1px solid #e5e5e5;
}
.footer > p {
    font-size: 0.8rem;
    display: inline-block;
    padding: 0 10px;
    transform: translateY(10px);
    background: white;
}
.acknowledgments {
    width: 80%;
    margin: 0 auto;
}
.img-logo-style {
    width: 3.5rem;
    float: left;
}
"""
)
""" learn_more_markdown = """