wenjiao committed on
Commit
454f02d
1 Parent(s): d371cb1

Update app.py

Files changed (1)
  1. app.py +520 -153
app.py CHANGED
@@ -1,184 +1,551 @@
  import os
- import gradio as gr
- import numpy as np
- import random
- import torch
- import subprocess
  import time
  import requests
- import json
- import base64
- from io import BytesIO
- from PIL import Image
- from huggingface_hub import login
- from huggingface_hub.utils import (
-     HfFolder
  )
- myip = os.environ["myip"]
- myport = os.environ["myport"]
-
- url = f"http://{myip}:{myport}"
- queue_size = 0
- def displayTextBox():
-     global queue_size
-     if queue_size > 4:
-         return [gr.update(visible=False), gr.update(visible=True)]
-     elif queue_size <= 4:
-         return [gr.update(visible=True), gr.update(visible=False)]
- def set_msg():
-     global queue_size
-     que_high_msg = "The current traffic is high with " + str(queue_size) + " in the queue. Please wait a moment."
-     que_normal_msg = "The current traffic is not high. You can submit your job now."
-
-     if queue_size > int(os.environ["max_queue_size"]):
-         return que_high_msg
-     else:
-         return que_normal_msg
- def img2img_generate(source_img, prompt, steps=25, strength=0.75, seed=42, guidance_scale=7.5):
-     print('image-to-image')
-     print("prompt: ", prompt)
-     print("steps: ", steps)
-     buffered = BytesIO()
-     source_img.save(buffered, format="JPEG")
-     img_b64 = base64.b64encode(buffered.getvalue())
-     timestamp = int(time.time()*1000)
-     data = {"source_img": img_b64.decode(), "prompt": prompt, "steps": steps,
-             "guidance_scale": guidance_scale, "seed": seed, "strength": strength,
-             "task_type": "1",
-             "timestamp": timestamp, "user": os.environ.get("token", "")}
-     start_time = time.time()
-     global queue_size
-     queue_size = queue_size + 1
-     resp = requests.post(url, data=json.dumps(data))
-     queue_size = queue_size - 1
-     try:
-         img_str = json.loads(resp.text)["img_str"]
-         print("Compute node: ", json.loads(resp.text)["ip"])
-     except:
-         print('No inference result. Please check server connection')
-         return None
-
-     img_byte = base64.b64decode(img_str)
-     img_io = BytesIO(img_byte) # convert image to file-like object
-     img = Image.open(img_io) # img is now PIL Image object
-     print("elapsed time: ", time.time() - start_time)
-     return img
-
-
- def txt2img_generate(prompt, steps=25, seed=42, guidance_scale=7.5):
-
-     print('text-to-image')
-     print("prompt: ", prompt)
-     print("steps: ", steps)
-     timestamp = int(time.time()*1000)
-     data = {"prompt": prompt,
-             "steps": steps, "guidance_scale": guidance_scale, "seed": seed,
-             "task_type": "0",
-             "timestamp": timestamp, "user": os.environ.get("token", "")}
-     start_time = time.time()
-     global queue_size
-     queue_size = queue_size + 1
-     resp = requests.post(url, data=json.dumps(data))
-     queue_size = queue_size - 1
-     try:
-         img_str = json.loads(resp.text)["img_str"]
-         print("Compute node: ", json.loads(resp.text)["ip"])
-     except:
-         print('No inference result. Please check server connection')
-         return None
-
-     img_byte = base64.b64decode(img_str)
-     img_io = BytesIO(img_byte) # convert image to file-like object
-     img = Image.open(img_io) # img is now PIL Image object
-     print("elapsed time: ", time.time() - start_time)
-     return img
- md = """
- This demo shows the accelerated inference performance of a Stable Diffusion model on **Intel Xeon Gold 64xx (4th Gen Intel Xeon Scalable Processors codenamed Sapphire Rapids)**. Try it and generate photorealistic images from text! Please note that the demo is in **preview** under limited HW resources. We are committed to continue improving the demo and happy to hear your feedbacks. Thanks for your trying!
- You may also want to try creating your own Stable Diffusion with few-shot fine-tuning. Please refer to our <a href=\"https://medium.com/intel-analytics-software/personalized-stable-diffusion-with-few-shot-fine-tuning-on-a-single-cpu-f01a3316b13\">blog</a> and <a href=\"https://github.com/intel/neural-compressor/tree/master/examples/pytorch/diffusion_model/diffusers/textual_inversion\">code</a> available in <a href=\"https://github.com/intel/neural-compressor\">**Intel Neural Compressor**</a> and <a href=\"https://github.com/huggingface/diffusers\">**Hugging Face Diffusers**</a>.
  """
- legal = """
- Performance varies by use, configuration and other factors. Learn more at www.Intel.com/PerformanceIndex. Performance results are based on testing as of dates shown in configurations and may not reflect all publicly available updates. See backup for configuration details. No product or component can be absolutely secure.
- © Intel Corporation. Intel, the Intel logo, and other Intel marks are trademarks of Intel Corporation or its subsidiaries. Other names and brands may be claimed as the property of others.
- """
- details = """
- 4th Gen Intel Xeon Scalable Processor Inference. Test by Intel on 01/06/2023. 1 node, 1S, Intel(R) Xeon(R) Gold 64xx CPU @ 3.0GHz 32 cores and software with 512GB (8x64GB DDR5 4800 MT/s [4800 MT/s]), microcode 0x2a000080, HT on, Turbo on, Ubuntu 22.04.1 LTS, 5.15.0-1026-aws, 200G Amazon Elastic Block Store. Multiple nodes connected with Elastic Network Adapter (ENA). PyTorch Nightly build (2.0.0.dev20230105+cpu), Transformers 4.25.1, Diffusers 0.11.1, oneDNN v2.7.2.
- """
-
- css = '''
- .instruction{position: absolute; top: 0;right: 0;margin-top: 0px !important}
- .arrow{position: absolute;top: 0;right: -110px;margin-top: -8px !important}
- #component-4, #component-3, #component-10{min-height: 0}
- .duplicate-button img{margin: 0}
- #mdStyle{font-size: 0.6rem}
- .generating.svelte-1w9161c { border: none }
- #txtGreenStyle {border: 2px solid #32ec48;}
- #txtOrangeStyle {border: 2px solid #e77718;}
- '''
-
- random_seed = random.randint(0, 2147483647)
-
- with gr.Blocks(css=css) as demo:
-     gr.Markdown("# Stable Diffusion Inference Demo on 4th Gen Intel Xeon Scalable Processors")
-     gr.Markdown(md)
-
-     textBoxGreen = gr.Textbox(set_msg, every=1, label='Real-time Jobs in Queue', elem_id='txtGreenStyle', visible=True)
-     textBoxOrange = gr.Textbox(set_msg, every=1, label='Real-time Jobs in Queue', elem_id='txtOrangeStyle', visible=False)
-     textBoxGreen.change(displayTextBox, outputs = [textBoxGreen, textBoxOrange])
-
-     with gr.Tab("Text-to-Image"):
-         with gr.Row(visible=True) as text_to_image:
-             with gr.Column():
-                 prompt = gr.inputs.Textbox(label='Prompt', default='a photo of an astronaut riding a horse on mars')
-                 inference_steps = gr.inputs.Slider(1, 100, label='Inference Steps - increase the steps for better quality (e.g., avoiding black image) ', default=20, step=1)
-                 seed = gr.inputs.Slider(0, 2147483647, label='Seed', default=random_seed, step=1)
-                 guidance_scale = gr.inputs.Slider(1.0, 20.0, label='Guidance Scale - how much the prompt will influence the results', default=7.5, step=0.1)
-                 txt2img_button = gr.Button("Generate Image")
-
-             with gr.Column():
-                 result_image = gr.Image()
-     with gr.Tab("Image-to-Image text-guided generation"):
-         with gr.Row(visible=True) as image_to_image:
-             with gr.Column():
-                 source_img = gr.Image(source="upload", type="pil", value="https://raw.githubusercontent.com/CompVis/stable-diffusion/main/assets/stable-samples/img2img/sketch-mountains-input.jpg")
-                 # source_img = gr.Image(source="upload", type="pil")
-                 prompt_2 = gr.inputs.Textbox(label='Prompt', default='A fantasy landscape, trending on artstation')
-                 inference_steps_2 = gr.inputs.Slider(1, 100, label='Inference Steps - increase the steps for better quality (e.g., avoiding black image) ', default=20, step=1)
-                 seed_2 = gr.inputs.Slider(0, 2147483647, label='Seed', default=random_seed, step=1)
-                 guidance_scale_2 = gr.inputs.Slider(1.0, 20.0, label='Guidance Scale - how much the prompt will influence the results', default=7.5, step=0.1)
-                 strength = gr.inputs.Slider(0.0, 1.0, label='Strength - adding more noise to it the larger the strength', default=0.75, step=0.01)
-                 img2img_button = gr.Button("Generate Image")
-             with gr.Column():
-                 result_image_2 = gr.Image()
-     txt2img_button.click(fn=txt2img_generate, inputs=[prompt, inference_steps, seed, guidance_scale], outputs=[result_image])
-     img2img_button.click(fn=img2img_generate, inputs=[source_img, prompt_2, inference_steps_2, strength, seed_2, guidance_scale_2], outputs=result_image_2)
-     gr.Markdown("**Additional Test Configuration Details:**", elem_id='mdStyle')
-     gr.Markdown(details, elem_id='mdStyle')
-     gr.Markdown("**Notices and Disclaimers:**", elem_id='mdStyle')
-     gr.Markdown(legal, elem_id='mdStyle')
- demo.queue(max_size=int(os.environ["max_job_size"]), concurrency_count=int(os.environ["max_job_size"])).launch(debug=True, show_api=False)
+ import argparse
+ from collections import defaultdict
+ import datetime
+ import json
  import os
  import time
+ import uuid
+
+ import gradio as gr
  import requests
+ from fastchat.conversation import (
+     Conversation,
+     compute_skip_echo_len,
+     SeparatorStyle,
  )
+ from fastchat.constants import LOGDIR
+ from fastchat.utils import (
+     build_logger,
+     server_error_msg,
+     violates_moderation,
+     moderation_msg,
+ )
+ from fastchat.serve.gradio_patch import Chatbot as grChatbot
+ from fastchat.serve.gradio_css import code_highlight_css
+ logger = build_logger("gradio_web_server", "gradio_web_server.log")
+ headers = {"User-Agent": "NeuralChat Client"}
+ no_change_btn = gr.Button.update()
+ enable_btn = gr.Button.update(interactive=True)
+ disable_btn = gr.Button.update(interactive=False)
+ controller_url = None
+ enable_moderation = False
+ conv_template_bf16 = Conversation(
+     system="A chat between a curious human and an artificial intelligence assistant. "
+     "The assistant gives helpful, detailed, and polite answers to the human's questions.",
+     roles=("Human", "Assistant"),
+     messages=(),
+     offset=0,
+     sep_style=SeparatorStyle.SINGLE,
+     sep="\n",
+     sep2="</s>",
+ )
+ def set_global_vars(controller_url_, enable_moderation_):
+     global controller_url, enable_moderation
+     controller_url = controller_url_
+     enable_moderation = enable_moderation_
+ def get_conv_log_filename():
+     t = datetime.datetime.now()
+     name = os.path.join(LOGDIR, f"{t.year}-{t.month:02d}-{t.day:02d}-conv.json")
+     return name
+ def get_model_list(controller_url):
+     ret = requests.post(controller_url + "/refresh_all_workers")
+     assert ret.status_code == 200
+     ret = requests.post(controller_url + "/list_models")
+     models = ret.json()["models"]
+     logger.info(f"Models: {models}")
+     return models
+ get_window_url_params = """
+ function() {
+     const params = new URLSearchParams(window.location.search);
+     url_params = Object.fromEntries(params);
+     console.log("url_params", url_params);
+     return url_params;
+ }
  """
+ def load_demo_single(models, url_params):
+     dropdown_update = gr.Dropdown.update(visible=True)
+     if "model" in url_params:
+         model = url_params["model"]
+         if model in models:
+             dropdown_update = gr.Dropdown.update(value=model, visible=True)
+
+     state = None
+     return (
+         state,
+         dropdown_update,
+         gr.Chatbot.update(visible=True),
+         gr.Textbox.update(visible=True),
+         gr.Button.update(visible=True),
+         gr.Row.update(visible=True),
+         gr.Accordion.update(visible=False),
+     )
+
+
+ def load_demo(url_params, request: gr.Request):
+     logger.info(f"load_demo. ip: {request.client.host}. params: {url_params}")
+     return load_demo_single(models, url_params)
+
+
+ def vote_last_response(state, vote_type, model_selector, request: gr.Request):
+     with open(get_conv_log_filename(), "a") as fout:
+         data = {
+             "tstamp": round(time.time(), 4),
+             "type": vote_type,
+             "model": model_selector,
+             "state": state.dict(),
+             "ip": request.client.host,
+         }
+         fout.write(json.dumps(data) + "\n")
+
+
+ def upvote_last_response(state, model_selector, request: gr.Request):
+     logger.info(f"upvote. ip: {request.client.host}")
+     vote_last_response(state, "upvote", model_selector, request)
+     return ("",) + (disable_btn,) * 3
+
+
+ def downvote_last_response(state, model_selector, request: gr.Request):
+     logger.info(f"downvote. ip: {request.client.host}")
+     vote_last_response(state, "downvote", model_selector, request)
+     return ("",) + (disable_btn,) * 3
+
+
+ def flag_last_response(state, model_selector, request: gr.Request):
+     logger.info(f"flag. ip: {request.client.host}")
+     vote_last_response(state, "flag", model_selector, request)
+     return ("",) + (disable_btn,) * 3
+
+
+ def regenerate(state, request: gr.Request):
+     logger.info(f"regenerate. ip: {request.client.host}")
+     state.messages[-1][-1] = None
+     state.skip_next = False
+     return (state, state.to_gradio_chatbot(), "") + (disable_btn,) * 5
+
+
+ def clear_history(request: gr.Request):
+     logger.info(f"clear_history. ip: {request.client.host}")
+     state = None
+     return (state, [], "") + (disable_btn,) * 5
+
+
+ def add_text(state, text, request: gr.Request):
+     logger.info(f"add_text. ip: {request.client.host}. len: {len(text)}")
+
+     if state is None:
+         state = conv_template_bf16.copy()
+
+     if len(text) <= 0:
+         state.skip_next = True
+         return (state, state.to_gradio_chatbot(), "") + (no_change_btn,) * 5
+     if enable_moderation:
+         flagged = violates_moderation(text)
+         if flagged:
+             logger.info(f"violate moderation. ip: {request.client.host}. text: {text}")
+             state.skip_next = True
+             return (state, state.to_gradio_chatbot(), moderation_msg) + (
+                 no_change_btn,
+             ) * 5
+
+     text = text[:1536] # Hard cut-off
+     state.append_message(state.roles[0], text)
+     state.append_message(state.roles[1], None)
+     state.skip_next = False
+     return (state, state.to_gradio_chatbot(), "") + (disable_btn,) * 5
+
+
+ def post_process_code(code):
+     sep = "\n```"
+     if sep in code:
+         blocks = code.split(sep)
+         if len(blocks) % 2 == 1:
+             for i in range(1, len(blocks), 2):
+                 blocks[i] = blocks[i].replace("\\_", "_")
+         code = sep.join(blocks)
+     return code
+
+
+ def http_bot(state, model_selector, temperature, max_new_tokens, request: gr.Request):
+     logger.info(f"http_bot. ip: {request.client.host}")
+     start_tstamp = time.time()
+     model_name = model_selector
+     temperature = float(temperature)
+     max_new_tokens = int(max_new_tokens)
+
+     if state.skip_next:
+         # This generate call is skipped due to invalid inputs
+         yield (state, state.to_gradio_chatbot()) + (no_change_btn,) * 5
+         return
+
+     if len(state.messages) == state.offset + 2:
+         # First round of conversation
+         new_state = conv_template_bf16.copy()
+         new_state.conv_id = uuid.uuid4().hex
+         new_state.model_name = state.model_name or model_selector
+         new_state.append_message(new_state.roles[0], state.messages[-2][1])
+         new_state.append_message(new_state.roles[1], None)
+         state = new_state
+
+     # Query worker address
+     ret = requests.post(
+         controller_url + "/get_worker_address", json={"model": model_name}
+     )
+     worker_addr = ret.json()["address"]
+     logger.info(f"model_name: {model_name}, worker_addr: {worker_addr}")
+
+     # No available worker
+     if worker_addr == "":
+         state.messages[-1][-1] = server_error_msg
+         yield (
+             state,
+             state.to_gradio_chatbot(),
+             disable_btn,
+             disable_btn,
+             disable_btn,
+             enable_btn,
+             enable_btn,
+         )
+         return
+
+     # Construct prompt
+     prompt = state.get_prompt()
+     skip_echo_len = compute_skip_echo_len(model_name, state, prompt)
+
+     # Make requests
+     pload = {
+         "model": model_name,
+         "prompt": prompt,
+         "temperature": temperature,
+         "max_new_tokens": max_new_tokens,
+         "stop": "</s>"
+     }
+     logger.info(f"==== request ====\n{pload}")
+     start_time = time.time()
+     state.messages[-1][-1] = "▌"
+     yield (state, state.to_gradio_chatbot()) + (disable_btn,) * 5
+     try:
+         # Stream output
+         response = requests.post(
+             worker_addr + "/worker_generate_stream",
+             headers=headers,
+             json=pload,
+             stream=True,
+             timeout=20,
+         )
+         for chunk in response.iter_lines(decode_unicode=False, delimiter=b"\0"):
+             if chunk:
+                 data = json.loads(chunk.decode())
+                 if data["error_code"] == 0:
+                     output = data["text"][skip_echo_len:].strip()
+                     output = post_process_code(output)
+                     state.messages[-1][-1] = output + "▌"
+                     yield (state, state.to_gradio_chatbot()) + (disable_btn,) * 5
+                 else:
+                     output = data["text"] + f" (error_code: {data['error_code']})"
+                     state.messages[-1][-1] = output
+                     yield (state, state.to_gradio_chatbot()) + (
+                         disable_btn,
+                         disable_btn,
+                         disable_btn,
+                         enable_btn,
+                         enable_btn,
+                     )
+                     return
+             time.sleep(0.02)
+     except requests.exceptions.RequestException as e:
+         state.messages[-1][-1] = server_error_msg + f" (error_code: 4)"
+         yield (state, state.to_gradio_chatbot()) + (
+             disable_btn,
+             disable_btn,
+             disable_btn,
+             enable_btn,
+             enable_btn,
+         )
+         return
+
+     finish_tstamp = time.time() - start_time
+     elapsed_time = "\n✅generation elapsed time: {}s".format(round(finish_tstamp, 4))
+
+     state.messages[-1][-1] = state.messages[-1][-1][:-1] + elapsed_time
+     yield (state, state.to_gradio_chatbot()) + (enable_btn,) * 5
+
+     logger.info(f"{output}")
+
+     with open(get_conv_log_filename(), "a") as fout:
+         data = {
+             "tstamp": round(finish_tstamp, 4),
+             "type": "chat",
+             "model": model_name,
+             "gen_params": {
+                 "temperature": temperature,
+                 "max_new_tokens": max_new_tokens,
+             },
+             "start": round(start_tstamp, 4),
+             "finish": round(start_tstamp, 4),
+             "state": state.dict(),
+             "ip": request.client.host,
+         }
+         fout.write(json.dumps(data) + "\n")
+
+
+ block_css = (
+     code_highlight_css
+     + """
+ pre {
+     white-space: pre-wrap; /* Since CSS 2.1 */
+     white-space: -moz-pre-wrap; /* Mozilla, since 1999 */
+     white-space: -pre-wrap; /* Opera 4-6 */
+     white-space: -o-pre-wrap; /* Opera 7 */
+     word-wrap: break-word; /* Internet Explorer 5.5+ */
+ }
+ #notice_markdown th {
+     display: none;
+ }
+
+ #notice_markdown {
+     text-align: center;
+     background: #874bec;
+     padding: 1%;
+ }
+
+ #notice_markdown h1, #notice_markdown h4 {
+     color: #fff;
+     margin-top: 0;
+ }
+
+ gradio-app {
+     background: linear-gradient(to bottom, #ba97d8, #5400ff) !important;
+     padding: 3%;
+ }
+
+ .gradio-container {
+     margin: 0 auto !important;
+     width: 70% !important;
+     padding: 0 !important;
+ }
+
+ #chatbot {
+     border-style: solid;
+     overflow: visible;
+     margin: 1% 4%;
+     width: 90%;
+     box-shadow: 0 15px 15px -5px rgba(0, 0, 0, 0.2);
+     border: 1px solid #ddd;
+ }
+
+ #text-box-style, #btn-style {
+     width: 90%;
+     margin: 1% 4%;
+ }
+
+
+ .user, .bot {
+     width: 80% !important;
+
+ }
+
+ .bot {
+     white-space: pre-wrap !important;
+     line-height: 1.3 !important;
+     display: flex;
+     flex-direction: column;
+     justify-content: flex-start;
+
+ }
+
+ #btn-send-style {
+     background: rgb(0, 180, 50);
+     color: #fff;
+ }
+
+ #btn-list-style {
+     background: #eee0;
+     border: 1px solid #691ef7;
+ }
+ """
+ )
+ def build_single_model_ui(models):
+     notice_markdown = """
+ # 🤖 NeuralChat
+
+ #### deployed on 4th Gen Intel Xeon Scalable Processors codenamed Sapphire Rapids
+
+ """
+     learn_more_markdown = """
+ """
+     state = gr.State()
+     notice = gr.Markdown(notice_markdown, elem_id="notice_markdown")
+
+     with gr.Row(elem_id="model_selector_row", visible=False):
+         model_selector = gr.Dropdown(
+             choices=models,
+             value=models[0] if len(models) > 0 else "",
+             interactive=True,
+             show_label=False,
+         ).style(container=False)
+
+     chatbot = grChatbot(elem_id="chatbot", visible=False).style(height=550)
+     with gr.Row(elem_id="text-box-style"):
+         with gr.Column(scale=20):
+             textbox = gr.Textbox(
+                 show_label=False,
+                 placeholder="Enter text and press ENTER",
+                 visible=False,
+             ).style(container=False)
+         with gr.Column(scale=1, min_width=50):
+             send_btn = gr.Button(value="Send", visible=False, elem_id="btn-send-style")
+
+     with gr.Row(visible=False, elem_id="btn-style") as button_row:
+         upvote_btn = gr.Button(value="👍 Upvote", interactive=False, elem_id="btn-list-style")
+         downvote_btn = gr.Button(value="👎 Downvote", interactive=False, elem_id="btn-list-style")
+         flag_btn = gr.Button(value="⚠️ Flag", interactive=False, visible=False, elem_id="btn-list-style")
+         # stop_btn = gr.Button(value="⏹️ Stop Generation", interactive=False)
+         regenerate_btn = gr.Button(value="🔄 Regenerate", interactive=False, elem_id="btn-list-style")
+         clear_btn = gr.Button(value="🗑️ Clear history", interactive=False, elem_id="btn-list-style")
+
+     with gr.Accordion("Parameters", open=False, visible=False) as parameter_row:
+         temperature = gr.Slider(
+             minimum=0.0,
+             maximum=1.0,
+             value=0.95,
+             step=0.1,
+             interactive=True,
+             label="Temperature",
+         )
+         max_output_tokens = gr.Slider(
+             minimum=0,
+             maximum=1024,
+             value=512,
+             step=64,
+             interactive=True,
+             label="Max output tokens",
+         )
+
+     gr.Markdown(learn_more_markdown)
+
+     # Register listeners
+     btn_list = [upvote_btn, downvote_btn, flag_btn, regenerate_btn, clear_btn]
+     upvote_btn.click(
+         upvote_last_response,
+         [state, model_selector],
+         [textbox, upvote_btn, downvote_btn, flag_btn],
+     )
+     downvote_btn.click(
+         downvote_last_response,
+         [state, model_selector],
+         [textbox, upvote_btn, downvote_btn, flag_btn],
+     )
+     flag_btn.click(
+         flag_last_response,
+         [state, model_selector],
+         [textbox, upvote_btn, downvote_btn, flag_btn],
+     )
+     regenerate_btn.click(regenerate, state, [state, chatbot, textbox] + btn_list).then(
+         http_bot,
+         [state, model_selector, temperature, max_output_tokens],
+         [state, chatbot] + btn_list,
+     )
+     clear_btn.click(clear_history, None, [state, chatbot, textbox] + btn_list)
+
+     model_selector.change(clear_history, None, [state, chatbot, textbox] + btn_list)
+
+     textbox.submit(
+         add_text, [state, textbox], [state, chatbot, textbox] + btn_list
+     ).then(
+         http_bot,
+         [state, model_selector, temperature, max_output_tokens],
+         [state, chatbot] + btn_list,
+     )
+     send_btn.click(
+         add_text, [state, textbox], [state, chatbot, textbox] + btn_list
+     ).then(
+         http_bot,
+         [state, model_selector, temperature, max_output_tokens],
+         [state, chatbot] + btn_list,
+     )
+
+     return state, model_selector, chatbot, textbox, send_btn, button_row, parameter_row
+
+
+ def build_demo(models):
+     with gr.Blocks(
+         title="NeuralChat · Intel",
+         theme=gr.themes.Base(),
+         css=block_css,
+     ) as demo:
+         url_params = gr.JSON(visible=False)
+
+         (
+             state,
+             model_selector,
+             chatbot,
+             textbox,
+             send_btn,
+             button_row,
+             parameter_row,
+         ) = build_single_model_ui(models)
+
+         if model_list_mode == "once":
+             demo.load(
+                 load_demo,
+                 [url_params],
+                 [
+                     state,
+                     model_selector,
+                     chatbot,
+                     textbox,
+                     send_btn,
+                     button_row,
+                     parameter_row,
+                 ],
+                 _js=get_window_url_params,
+             )
+         else:
+             raise ValueError(f"Unknown model list mode: {model_list_mode}")
+
+     return demo
+
+
+ if __name__ == "__main__":
+
+     controller_url = "http://mlp-dgx-01.sh.intel.com:21001"
+     host = "mlp-dgx-01.sh.intel.com"
+     # port = "mlp-dgx-01.sh.intel.com"
+     concurrency_count = 10
+     model_list_mode = "once"
+     share = True
+     moderate = False
+
+     set_global_vars(controller_url, moderate)
+     models = get_model_list(controller_url)
+
+     demo = build_demo(models)
+     demo.queue(
+         concurrency_count=concurrency_count, status_update_rate=10, api_open=False
+     ).launch(
+         server_name=host, share=share, max_threads=200
+     )
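
For reference, the rewritten app.py assumes a FastChat-style controller/worker deployment: the controller exposes /refresh_all_workers, /list_models, and /get_worker_address, and each model worker streams NUL-delimited JSON chunks from /worker_generate_stream. The standalone sketch below mirrors get_model_list and http_bot from the diff; the controller URL, prompt, and generation settings are illustrative placeholders, not values taken from this commit.

# Minimal client sketch of the controller/worker protocol used above.
# Endpoint paths and the b"\0"-delimited JSON stream format come from the
# committed code; the URL, prompt, and parameters are placeholder assumptions.
import json

import requests

controller_url = "http://localhost:21001"  # placeholder controller address

# 1) Ask the controller which models are currently served.
requests.post(controller_url + "/refresh_all_workers")
models = requests.post(controller_url + "/list_models").json()["models"]

# 2) Resolve a worker address for the chosen model.
model = models[0]
worker_addr = requests.post(
    controller_url + "/get_worker_address", json={"model": model}
).json()["address"]

# 3) Stream generation; each chunk is a JSON object terminated by b"\0".
pload = {
    "model": model,
    "prompt": "Human: Hello!\nAssistant:",  # http_bot builds this via state.get_prompt()
    "temperature": 0.95,
    "max_new_tokens": 128,
    "stop": "</s>",
}
response = requests.post(
    worker_addr + "/worker_generate_stream", json=pload, stream=True, timeout=20
)
for chunk in response.iter_lines(decode_unicode=False, delimiter=b"\0"):
    if chunk:
        data = json.loads(chunk.decode())
        if data["error_code"] == 0:
            # data["text"] is cumulative; http_bot strips the echoed prompt
            # with compute_skip_echo_len before displaying it.
            print(data["text"])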