Update app.py
app.py CHANGED
@@ -1,184 +1,551 @@
 import os
-import gradio as gr
-import numpy as np
-import random
-import torch
-import subprocess
 import time
 import requests
-import json
 
-import base64
-from io import BytesIO
-from PIL import Image
-
-from huggingface_hub.utils import (
-    HfFolder
 )
 
-myip = os.environ["myip"]
-myport = os.environ["myport"]
-
-url = f"http://{myip}:{myport}"
 
-
 
-
-    global queue_size
-    if queue_size > 4:
-        return [gr.update(visible=False), gr.update(visible=True)]
-    elif queue_size <= 4:
-        return [gr.update(visible=True), gr.update(visible=False)]
 
 
-
-
-    que_high_msg = "The current traffic is high with " + str(queue_size) + " in the queue. Please wait a moment."
-    que_normal_msg = "The current traffic is not high. You can submit your job now."
-
-    if queue_size > int(os.environ["max_queue_size"]):
-        return que_high_msg
-    else:
-        return que_normal_msg
 
 
-def img2img_generate(source_img, prompt, steps=25, strength=0.75, seed=42, guidance_scale=7.5):
 
-
-
-
-    buffered = BytesIO()
-    source_img.save(buffered, format="JPEG")
-    img_b64 = base64.b64encode(buffered.getvalue())
-    timestamp = int(time.time()*1000)
 
-    data = {"source_img": img_b64.decode(), "prompt": prompt, "steps": steps,
-            "guidance_scale": guidance_scale, "seed": seed, "strength": strength,
-            "task_type": "1",
-            "timestamp": timestamp, "user": os.environ.get("token", "")}
 
-    start_time = time.time()
-    global queue_size
-    queue_size = queue_size + 1
-    resp = requests.post(url, data=json.dumps(data))
-    queue_size = queue_size - 1
 
-    try:
-        img_str = json.loads(resp.text)["img_str"]
-        print("Compute node: ", json.loads(resp.text)["ip"])
-    except:
-        print('No inference result. Please check server connection')
-        return None
-
-    img_byte = base64.b64decode(img_str)
-    img_io = BytesIO(img_byte)  # convert image to file-like object
-    img = Image.open(img_io)  # img is now PIL Image object
-    print("elapsed time: ", time.time() - start_time)
-    return img
-
-
-def txt2img_generate(prompt, steps=25, seed=42, guidance_scale=7.5):
-
-    print('text-to-image')
-    print("prompt: ", prompt)
-    print("steps: ", steps)
-    timestamp = int(time.time()*1000)
-    data = {"prompt": prompt,
-            "steps": steps, "guidance_scale": guidance_scale, "seed": seed,
-            "task_type": "0",
-            "timestamp": timestamp, "user": os.environ.get("token", "")}
-    start_time = time.time()
-    global queue_size
-    queue_size = queue_size + 1
-    resp = requests.post(url, data=json.dumps(data))
-    queue_size = queue_size - 1
-    try:
-        img_str = json.loads(resp.text)["img_str"]
-        print("Compute node: ", json.loads(resp.text)["ip"])
-    except:
-        print('No inference result. Please check server connection')
-        return None
-
-    img_byte = base64.b64decode(img_str)
-    img_io = BytesIO(img_byte)  # convert image to file-like object
-    img = Image.open(img_io)  # img is now PIL Image object
-    print("elapsed time: ", time.time() - start_time)
 
-    return img
 
 
-
-
-
 """
 
-legal = """
-Performance varies by use, configuration and other factors. Learn more at www.Intel.com/PerformanceIndex. Performance results are based on testing as of dates shown in configurations and may not reflect all publicly available updates. See backup for configuration details. No product or component can be absolutely secure.
-© Intel Corporation. Intel, the Intel logo, and other Intel marks are trademarks of Intel Corporation or its subsidiaries. Other names and brands may be claimed as the property of others.
-"""
 
-
-
-""
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-with
-
-
-
-
-
-
-
-
-
-
 
 
-
-
-        with gr.Column():
-            source_img = gr.Image(source="upload", type="pil", value="https://raw.githubusercontent.com/CompVis/stable-diffusion/main/assets/stable-samples/img2img/sketch-mountains-input.jpg")
-            # source_img = gr.Image(source="upload", type="pil")
-            prompt_2 = gr.inputs.Textbox(label='Prompt', default='A fantasy landscape, trending on artstation')
-            inference_steps_2 = gr.inputs.Slider(1, 100, label='Inference Steps - increase the steps for better quality (e.g., avoiding black image) ', default=20, step=1)
-            seed_2 = gr.inputs.Slider(0, 2147483647, label='Seed', default=random_seed, step=1)
-            guidance_scale_2 = gr.inputs.Slider(1.0, 20.0, label='Guidance Scale - how much the prompt will influence the results', default=7.5, step=0.1)
-            strength = gr.inputs.Slider(0.0, 1.0, label='Strength - adding more noise to it the larger the strength', default=0.75, step=0.01)
-            img2img_button = gr.Button("Generate Image")
 
-
-
 
-    txt2img_button.click(fn=txt2img_generate, inputs=[prompt, inference_steps, seed, guidance_scale], outputs=[result_image])
 
-
 
-
-    gr.Markdown(details, elem_id='mdStyle')
 
-
-    gr.Markdown(legal, elem_id='mdStyle')
 
-
 
+import argparse
+from collections import defaultdict
+import datetime
+import json
 import os
 import time
+import uuid
+
+import gradio as gr
 import requests
 
+from fastchat.conversation import (
+    Conversation,
+    compute_skip_echo_len,
+    SeparatorStyle,
 )
+from fastchat.constants import LOGDIR
+from fastchat.utils import (
+    build_logger,
+    server_error_msg,
+    violates_moderation,
+    moderation_msg,
+)
+from fastchat.serve.gradio_patch import Chatbot as grChatbot
+from fastchat.serve.gradio_css import code_highlight_css
 
 
+logger = build_logger("gradio_web_server", "gradio_web_server.log")
 
+headers = {"User-Agent": "NeuralChat Client"}
 
+no_change_btn = gr.Button.update()
+enable_btn = gr.Button.update(interactive=True)
+disable_btn = gr.Button.update(interactive=False)
 
+controller_url = None
+enable_moderation = False
 
+conv_template_bf16 = Conversation(
+    system="A chat between a curious human and an artificial intelligence assistant. "
+    "The assistant gives helpful, detailed, and polite answers to the human's questions.",
+    roles=("Human", "Assistant"),
+    messages=(),
+    offset=0,
+    sep_style=SeparatorStyle.SINGLE,
+    sep="\n",
+    sep2="</s>",
+)
 
 
+def set_global_vars(controller_url_, enable_moderation_):
+    global controller_url, enable_moderation
+    controller_url = controller_url_
+    enable_moderation = enable_moderation_
 
 
+def get_conv_log_filename():
+    t = datetime.datetime.now()
+    name = os.path.join(LOGDIR, f"{t.year}-{t.month:02d}-{t.day:02d}-conv.json")
+    return name
 
 
+def get_model_list(controller_url):
+    ret = requests.post(controller_url + "/refresh_all_workers")
+    assert ret.status_code == 200
+    ret = requests.post(controller_url + "/list_models")
+    models = ret.json()["models"]
+    logger.info(f"Models: {models}")
+    return models
 
 
+get_window_url_params = """
+function() {
+    const params = new URLSearchParams(window.location.search);
+    url_params = Object.fromEntries(params);
+    console.log("url_params", url_params);
+    return url_params;
+}
 """
 
 
+def load_demo_single(models, url_params):
+    dropdown_update = gr.Dropdown.update(visible=True)
+    if "model" in url_params:
+        model = url_params["model"]
+        if model in models:
+            dropdown_update = gr.Dropdown.update(value=model, visible=True)
+
+    state = None
+    return (
+        state,
+        dropdown_update,
+        gr.Chatbot.update(visible=True),
+        gr.Textbox.update(visible=True),
+        gr.Button.update(visible=True),
+        gr.Row.update(visible=True),
+        gr.Accordion.update(visible=False),
+    )
+
+
+def load_demo(url_params, request: gr.Request):
+    logger.info(f"load_demo. ip: {request.client.host}. params: {url_params}")
+    return load_demo_single(models, url_params)
+
+
+def vote_last_response(state, vote_type, model_selector, request: gr.Request):
+    with open(get_conv_log_filename(), "a") as fout:
+        data = {
+            "tstamp": round(time.time(), 4),
+            "type": vote_type,
+            "model": model_selector,
+            "state": state.dict(),
+            "ip": request.client.host,
+        }
+        fout.write(json.dumps(data) + "\n")
+
+
+def upvote_last_response(state, model_selector, request: gr.Request):
+    logger.info(f"upvote. ip: {request.client.host}")
+    vote_last_response(state, "upvote", model_selector, request)
+    return ("",) + (disable_btn,) * 3
+
+
+def downvote_last_response(state, model_selector, request: gr.Request):
+    logger.info(f"downvote. ip: {request.client.host}")
+    vote_last_response(state, "downvote", model_selector, request)
+    return ("",) + (disable_btn,) * 3
+
+
+def flag_last_response(state, model_selector, request: gr.Request):
+    logger.info(f"flag. ip: {request.client.host}")
+    vote_last_response(state, "flag", model_selector, request)
+    return ("",) + (disable_btn,) * 3
+
+
+def regenerate(state, request: gr.Request):
+    logger.info(f"regenerate. ip: {request.client.host}")
+    state.messages[-1][-1] = None
+    state.skip_next = False
+    return (state, state.to_gradio_chatbot(), "") + (disable_btn,) * 5
+
+
+def clear_history(request: gr.Request):
+    logger.info(f"clear_history. ip: {request.client.host}")
+    state = None
+    return (state, [], "") + (disable_btn,) * 5
+
+
+def add_text(state, text, request: gr.Request):
+    logger.info(f"add_text. ip: {request.client.host}. len: {len(text)}")
+
+    if state is None:
+        state = conv_template_bf16.copy()
+
+    if len(text) <= 0:
+        state.skip_next = True
+        return (state, state.to_gradio_chatbot(), "") + (no_change_btn,) * 5
+    if enable_moderation:
+        flagged = violates_moderation(text)
+        if flagged:
+            logger.info(f"violate moderation. ip: {request.client.host}. text: {text}")
+            state.skip_next = True
+            return (state, state.to_gradio_chatbot(), moderation_msg) + (
+                no_change_btn,
+            ) * 5
+
+    text = text[:1536]  # Hard cut-off
+    state.append_message(state.roles[0], text)
+    state.append_message(state.roles[1], None)
+    state.skip_next = False
+    return (state, state.to_gradio_chatbot(), "") + (disable_btn,) * 5
+
+
+def post_process_code(code):
+    sep = "\n```"
+    if sep in code:
+        blocks = code.split(sep)
+        if len(blocks) % 2 == 1:
+            for i in range(1, len(blocks), 2):
+                blocks[i] = blocks[i].replace("\\_", "_")
+        code = sep.join(blocks)
+    return code
+
+
+def http_bot(state, model_selector, temperature, max_new_tokens, request: gr.Request):
+    logger.info(f"http_bot. ip: {request.client.host}")
+    start_tstamp = time.time()
+    model_name = model_selector
+    temperature = float(temperature)
+    max_new_tokens = int(max_new_tokens)
+
+    if state.skip_next:
+        # This generate call is skipped due to invalid inputs
+        yield (state, state.to_gradio_chatbot()) + (no_change_btn,) * 5
+        return
+
+    if len(state.messages) == state.offset + 2:
+        # First round of conversation
+        new_state = conv_template_bf16.copy()
+        new_state.conv_id = uuid.uuid4().hex
+        new_state.model_name = state.model_name or model_selector
+        new_state.append_message(new_state.roles[0], state.messages[-2][1])
+        new_state.append_message(new_state.roles[1], None)
+        state = new_state
+
+    # Query worker address
+    ret = requests.post(
+        controller_url + "/get_worker_address", json={"model": model_name}
+    )
+    worker_addr = ret.json()["address"]
+    logger.info(f"model_name: {model_name}, worker_addr: {worker_addr}")
+
+    # No available worker
+    if worker_addr == "":
+        state.messages[-1][-1] = server_error_msg
+        yield (
+            state,
+            state.to_gradio_chatbot(),
+            disable_btn,
+            disable_btn,
+            disable_btn,
+            enable_btn,
+            enable_btn,
+        )
+        return
+
+    # Construct prompt
+    prompt = state.get_prompt()
+    skip_echo_len = compute_skip_echo_len(model_name, state, prompt)
+
+    # Make requests
+    pload = {
+        "model": model_name,
+        "prompt": prompt,
+        "temperature": temperature,
+        "max_new_tokens": max_new_tokens,
+        "stop": "</s>"
+    }
+    logger.info(f"==== request ====\n{pload}")
 
+    start_time = time.time()
 
+    state.messages[-1][-1] = "▌"
+    yield (state, state.to_gradio_chatbot()) + (disable_btn,) * 5
 
+    try:
+        # Stream output
+        response = requests.post(
+            worker_addr + "/worker_generate_stream",
+            headers=headers,
+            json=pload,
+            stream=True,
+            timeout=20,
+        )
+        for chunk in response.iter_lines(decode_unicode=False, delimiter=b"\0"):
+            if chunk:
+                data = json.loads(chunk.decode())
+                if data["error_code"] == 0:
+                    output = data["text"][skip_echo_len:].strip()
+                    output = post_process_code(output)
+                    state.messages[-1][-1] = output + "▌"
+                    yield (state, state.to_gradio_chatbot()) + (disable_btn,) * 5
+                else:
+                    output = data["text"] + f" (error_code: {data['error_code']})"
+                    state.messages[-1][-1] = output
+                    yield (state, state.to_gradio_chatbot()) + (
+                        disable_btn,
+                        disable_btn,
+                        disable_btn,
+                        enable_btn,
+                        enable_btn,
+                    )
+                    return
+                time.sleep(0.02)
+    except requests.exceptions.RequestException as e:
+        state.messages[-1][-1] = server_error_msg + f" (error_code: 4)"
+        yield (state, state.to_gradio_chatbot()) + (
+            disable_btn,
+            disable_btn,
+            disable_btn,
+            enable_btn,
+            enable_btn,
+        )
+        return
+
+    finish_tstamp = time.time() - start_time
+    elapsed_time = "\n✅generation elapsed time: {}s".format(round(finish_tstamp, 4))
+
+    state.messages[-1][-1] = state.messages[-1][-1][:-1] + elapsed_time
+    yield (state, state.to_gradio_chatbot()) + (enable_btn,) * 5
+
+    logger.info(f"{output}")
+
+    with open(get_conv_log_filename(), "a") as fout:
+        data = {
+            "tstamp": round(finish_tstamp, 4),
+            "type": "chat",
+            "model": model_name,
+            "gen_params": {
+                "temperature": temperature,
+                "max_new_tokens": max_new_tokens,
+            },
+            "start": round(start_tstamp, 4),
+            "finish": round(start_tstamp, 4),
+            "state": state.dict(),
+            "ip": request.client.host,
+        }
+        fout.write(json.dumps(data) + "\n")
+
+
+block_css = (
+    code_highlight_css
+    + """
+pre {
+    white-space: pre-wrap;       /* Since CSS 2.1 */
+    white-space: -moz-pre-wrap;  /* Mozilla, since 1999 */
+    white-space: -pre-wrap;      /* Opera 4-6 */
+    white-space: -o-pre-wrap;    /* Opera 7 */
+    word-wrap: break-word;       /* Internet Explorer 5.5+ */
+}
+#notice_markdown th {
+    display: none;
+}
+
+#notice_markdown {
+    text-align: center;
+    background: #874bec;
+    padding: 1%;
+}
+
+#notice_markdown h1, #notice_markdown h4 {
+    color: #fff;
+    margin-top: 0;
+}
+
+gradio-app {
+    background: linear-gradient(to bottom, #ba97d8, #5400ff) !important;
+    padding: 3%;
+}
+
+.gradio-container {
+    margin: 0 auto !important;
+    width: 70% !important;
+    padding: 0 !important;
+}
+
+#chatbot {
+    border-style: solid;
+    overflow: visible;
+    margin: 1% 4%;
+    width: 90%;
+    box-shadow: 0 15px 15px -5px rgba(0, 0, 0, 0.2);
+    border: 1px solid #ddd;
+}
+
+#text-box-style, #btn-style {
+    width: 90%;
+    margin: 1% 4%;
+}
+
+
+.user, .bot {
+    width: 80% !important;
+
+}
+
+.bot {
+    white-space: pre-wrap !important;
+    line-height: 1.3 !important;
+    display: flex;
+    flex-direction: column;
+    justify-content: flex-start;
+
+}
+
+#btn-send-style {
+    background: rgb(0, 180, 50);
+    color: #fff;
+}
+
+#btn-list-style {
+    background: #eee0;
+    border: 1px solid #691ef7;
+}
+"""
+)
 
 
+def build_single_model_ui(models):
+    notice_markdown = """
+# 🤖 NeuralChat
 
+#### deployed on 4th Gen Intel Xeon Scalable Processors codenamed Sapphire Rapids
 
+"""
 
+    learn_more_markdown = """
+"""
 
+    state = gr.State()
+    notice = gr.Markdown(notice_markdown, elem_id="notice_markdown")
+
+    with gr.Row(elem_id="model_selector_row", visible=False):
+        model_selector = gr.Dropdown(
+            choices=models,
+            value=models[0] if len(models) > 0 else "",
+            interactive=True,
+            show_label=False,
+        ).style(container=False)
+
+    chatbot = grChatbot(elem_id="chatbot", visible=False).style(height=550)
+    with gr.Row(elem_id="text-box-style"):
+        with gr.Column(scale=20):
+            textbox = gr.Textbox(
+                show_label=False,
+                placeholder="Enter text and press ENTER",
+                visible=False,
+            ).style(container=False)
+        with gr.Column(scale=1, min_width=50):
+            send_btn = gr.Button(value="Send", visible=False, elem_id="btn-send-style")
+
+    with gr.Row(visible=False, elem_id="btn-style") as button_row:
+        upvote_btn = gr.Button(value="👍 Upvote", interactive=False, elem_id="btn-list-style")
+        downvote_btn = gr.Button(value="👎 Downvote", interactive=False, elem_id="btn-list-style")
+        flag_btn = gr.Button(value="⚠️ Flag", interactive=False, visible=False, elem_id="btn-list-style")
+        # stop_btn = gr.Button(value="⏹️ Stop Generation", interactive=False)
+        regenerate_btn = gr.Button(value="🔄 Regenerate", interactive=False, elem_id="btn-list-style")
+        clear_btn = gr.Button(value="🗑️ Clear history", interactive=False, elem_id="btn-list-style")
+
+    with gr.Accordion("Parameters", open=False, visible=False) as parameter_row:
+        temperature = gr.Slider(
+            minimum=0.0,
+            maximum=1.0,
+            value=0.95,
+            step=0.1,
+            interactive=True,
+            label="Temperature",
+        )
+        max_output_tokens = gr.Slider(
+            minimum=0,
+            maximum=1024,
+            value=512,
+            step=64,
+            interactive=True,
+            label="Max output tokens",
+        )
+
+    gr.Markdown(learn_more_markdown)
+
+    # Register listeners
+    btn_list = [upvote_btn, downvote_btn, flag_btn, regenerate_btn, clear_btn]
+    upvote_btn.click(
+        upvote_last_response,
+        [state, model_selector],
+        [textbox, upvote_btn, downvote_btn, flag_btn],
+    )
+    downvote_btn.click(
+        downvote_last_response,
+        [state, model_selector],
+        [textbox, upvote_btn, downvote_btn, flag_btn],
+    )
+    flag_btn.click(
+        flag_last_response,
+        [state, model_selector],
+        [textbox, upvote_btn, downvote_btn, flag_btn],
+    )
+    regenerate_btn.click(regenerate, state, [state, chatbot, textbox] + btn_list).then(
+        http_bot,
+        [state, model_selector, temperature, max_output_tokens],
+        [state, chatbot] + btn_list,
+    )
+    clear_btn.click(clear_history, None, [state, chatbot, textbox] + btn_list)
+
+    model_selector.change(clear_history, None, [state, chatbot, textbox] + btn_list)
+
+    textbox.submit(
+        add_text, [state, textbox], [state, chatbot, textbox] + btn_list
+    ).then(
+        http_bot,
+        [state, model_selector, temperature, max_output_tokens],
+        [state, chatbot] + btn_list,
+    )
+    send_btn.click(
+        add_text, [state, textbox], [state, chatbot, textbox] + btn_list
+    ).then(
+        http_bot,
+        [state, model_selector, temperature, max_output_tokens],
+        [state, chatbot] + btn_list,
+    )
+
+    return state, model_selector, chatbot, textbox, send_btn, button_row, parameter_row
+
+
+def build_demo(models):
+    with gr.Blocks(
+        title="NeuralChat · Intel",
+        theme=gr.themes.Base(),
+        css=block_css,
+    ) as demo:
+        url_params = gr.JSON(visible=False)
+
+        (
+            state,
+            model_selector,
+            chatbot,
+            textbox,
+            send_btn,
+            button_row,
+            parameter_row,
+        ) = build_single_model_ui(models)
+
+        if model_list_mode == "once":
+            demo.load(
+                load_demo,
+                [url_params],
+                [
+                    state,
+                    model_selector,
+                    chatbot,
+                    textbox,
+                    send_btn,
+                    button_row,
+                    parameter_row,
+                ],
+                _js=get_window_url_params,
+            )
+        else:
+            raise ValueError(f"Unknown model list mode: {model_list_mode}")
+
+    return demo
+
+
+if __name__ == "__main__":
+
+    controller_url = "http://mlp-dgx-01.sh.intel.com:21001"
+    host = "mlp-dgx-01.sh.intel.com"
+    # port = "mlp-dgx-01.sh.intel.com"
+    concurrency_count = 10
+    model_list_mode = "once"
+    share = True
+    moderate = False
+
+    set_global_vars(controller_url, moderate)
+    models = get_model_list(controller_url)
+
+    demo = build_demo(models)
+    demo.queue(
+        concurrency_count=concurrency_count, status_update_rate=10, api_open=False
+    ).launch(
+        server_name=host, share=share, max_threads=200
+    )