IELTS8 committed
Commit 5be8754 • 1 Parent(s): 01dda6f

Upload app.py

Files changed (1)
app.py +606 -0
app.py ADDED

import argparse
import datetime
import hashlib
import json
import os
import subprocess
import sys
import time

import gradio as gr
import requests

from llava.constants import LOGDIR
from llava.conversation import SeparatorStyle, conv_templates, default_conversation
from llava.utils import (
    build_logger,
    moderation_msg,
    server_error_msg,
    violates_moderation,
)

logger = build_logger("gradio_web_server", "gradio_web_server.log")

headers = {"User-Agent": "LLaVA Client"}

no_change_btn = gr.Button.update()
enable_btn = gr.Button.update(interactive=True)
disable_btn = gr.Button.update(interactive=False)
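
# Sort-key trick: models named here sort ahead of everything else in the
# model dropdown (their values compare lower than ordinary model names).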
priority = {
    "vicuna-13b": "aaaaaaa",
    "koala-13b": "aaaaaab",
}
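
# One conversation log file per day, written as JSON Lines under LOGDIR.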
def get_conv_log_filename():
    t = datetime.datetime.now()
    name = os.path.join(LOGDIR, f"{t.year}-{t.month:02d}-{t.day:02d}-conv.json")
    return name


def get_model_list():
    ret = requests.post(args.controller_url + "/refresh_all_workers")
    assert ret.status_code == 200
    ret = requests.post(args.controller_url + "/list_models")
    models = ret.json()["models"]
    models.sort(key=lambda x: priority.get(x, x))
    logger.info(f"Models: {models}")
    return models
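
# JavaScript run in the browser on page load; it returns the URL query
# parameters (e.g. ?model=...) so the demo can preselect a model.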
get_window_url_params = """
function() {
    const params = new URLSearchParams(window.location.search);
    const url_params = Object.fromEntries(params);
    console.log(url_params);
    return url_params;
}
"""
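
# `models` is a module-level global: build_demo() populates it, and
# load_demo() reads it when the page is opened with URL parameters.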
def load_demo(url_params, request: gr.Request):
    logger.info(f"load_demo. ip: {request.client.host}. params: {url_params}")

    dropdown_update = gr.Dropdown.update(visible=True)
    if "model" in url_params:
        model = url_params["model"]
        if model in models:
            dropdown_update = gr.Dropdown.update(value=model, visible=True)

    state = default_conversation.copy()
    return state, dropdown_update

def load_demo_refresh_model_list(request: gr.Request):
    logger.info(f"load_demo. ip: {request.client.host}")
    models = get_model_list()
    state = default_conversation.copy()

    models_downloaded = bool(models)

    model_dropdown_kwargs = {
        "choices": [],
        "value": "Downloading the models...",
        "interactive": models_downloaded,
    }

    if models_downloaded:
        model_dropdown_kwargs["choices"] = models
        # model_dropdown_kwargs["value"] = models[0]
        model_dropdown_kwargs["value"] = "MaViLa-13b"

    models_dropdown_update = gr.Dropdown.update(**model_dropdown_kwargs)

    send_button_update = gr.Button.update(
        interactive=models_downloaded,
    )

    return state, models_dropdown_update, send_button_update
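
# Each vote is appended to the daily conversation log as one JSON line; the
# vote handlers then clear the textbox and disable the three vote buttons.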
def vote_last_response(state, vote_type, model_selector, request: gr.Request):
    with open(get_conv_log_filename(), "a") as fout:
        data = {
            "tstamp": round(time.time(), 4),
            "type": vote_type,
            "model": model_selector,
            "state": state.dict(),
            "ip": request.client.host,
        }
        fout.write(json.dumps(data) + "\n")

def upvote_last_response(state, model_selector, request: gr.Request):
    logger.info(f"upvote. ip: {request.client.host}")
    vote_last_response(state, "upvote", model_selector, request)
    return ("",) + (disable_btn,) * 3


def downvote_last_response(state, model_selector, request: gr.Request):
    logger.info(f"downvote. ip: {request.client.host}")
    vote_last_response(state, "downvote", model_selector, request)
    return ("",) + (disable_btn,) * 3


def flag_last_response(state, model_selector, request: gr.Request):
    logger.info(f"flag. ip: {request.client.host}")
    vote_last_response(state, "flag", model_selector, request)
    return ("",) + (disable_btn,) * 3
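
# Drop the last model reply, re-tag the previous human message with the chosen
# image_process_mode, and let the chained http_bot call regenerate it.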
def regenerate(state, image_process_mode, request: gr.Request):
    logger.info(f"regenerate. ip: {request.client.host}")
    state.messages[-1][-1] = None
    prev_human_msg = state.messages[-2]
    if isinstance(prev_human_msg[1], (tuple, list)):
        prev_human_msg[1] = (*prev_human_msg[1][:2], image_process_mode)
    state.skip_next = False
    return (state, state.to_gradio_chatbot(), "", None) + (disable_btn,) * 5


def clear_history(request: gr.Request):
    logger.info(f"clear_history. ip: {request.client.host}")
    state = default_conversation.copy()
    return (state, state.to_gradio_chatbot(), "", None) + (disable_btn,) * 5
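
# Validate and append the user turn. Hard character limits (1536, or 1200 when
# an image is attached) keep the prompt bounded, and the literal "<image>"
# token marks where the image is inserted into the prompt.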
def add_text(state, text, image, image_process_mode, request: gr.Request):
    logger.info(f"add_text. ip: {request.client.host}. len: {len(text)}")
    if len(text) <= 0 and image is None:
        state.skip_next = True
        return (state, state.to_gradio_chatbot(), "", None) + (no_change_btn,) * 5
    if args.moderate:
        flagged = violates_moderation(text)
        if flagged:
            state.skip_next = True
            return (state, state.to_gradio_chatbot(), moderation_msg, None) + (
                no_change_btn,
            ) * 5

    text = text[:1536]  # Hard cut-off
    if image is not None:
        text = text[:1200]  # Hard cut-off for images
        if "<image>" not in text:
            # text = '<Image><image></Image>' + text
            text = text + "\n<image>"
        text = (text, image, image_process_mode)
        if len(state.get_images(return_pil=True)) > 0:
            state = default_conversation.copy()
    state.append_message(state.roles[0], text)
    state.append_message(state.roles[1], None)
    state.skip_next = False
    return (state, state.to_gradio_chatbot(), "", None) + (disable_btn,) * 5
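
# Main generation callback: on the first turn, pick a conversation template
# from the model name, resolve a worker address via the controller, then
# stream the reply over HTTP, yielding partial UI updates as tokens arrive.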
def http_bot(
    state, model_selector, temperature, top_p, max_new_tokens, request: gr.Request
):
    logger.info(f"http_bot. ip: {request.client.host}")
    start_tstamp = time.time()
    model_name = model_selector

    if state.skip_next:
        # This generate call is skipped due to invalid inputs
        yield (state, state.to_gradio_chatbot()) + (no_change_btn,) * 5
        return

    if len(state.messages) == state.offset + 2:
        # First round of conversation
        if "llava" in model_name.lower():
            if "llama-2" in model_name.lower():
                template_name = "llava_llama_2"
            elif "v1" in model_name.lower():
                if "mmtag" in model_name.lower():
                    template_name = "v1_mmtag"
                elif (
                    "plain" in model_name.lower()
                    and "finetune" not in model_name.lower()
                ):
                    template_name = "v1_mmtag"
                else:
                    template_name = "llava_v1"
            elif "mpt" in model_name.lower():
                template_name = "mpt"
            else:
                if "mmtag" in model_name.lower():
                    template_name = "v0_mmtag"
                elif (
                    "plain" in model_name.lower()
                    and "finetune" not in model_name.lower()
                ):
                    template_name = "v0_mmtag"
                else:
                    template_name = "llava_v0"
        elif "mpt" in model_name:
            template_name = "mpt_text"
        elif "llama-2" in model_name:
            template_name = "llama_2"
        else:
            template_name = "vicuna_v1"
        new_state = conv_templates[template_name].copy()
        new_state.append_message(new_state.roles[0], state.messages[-2][1])
        new_state.append_message(new_state.roles[1], None)
        state = new_state

    # Query worker address
    controller_url = args.controller_url
    ret = requests.post(
        controller_url + "/get_worker_address", json={"model": model_name}
    )
    worker_addr = ret.json()["address"]
    logger.info(f"model_name: {model_name}, worker_addr: {worker_addr}")

    # No available worker
    if worker_addr == "":
        state.messages[-1][-1] = server_error_msg
        yield (
            state,
            state.to_gradio_chatbot(),
            disable_btn,
            disable_btn,
            disable_btn,
            enable_btn,
            enable_btn,
        )
        return

    # Construct prompt
    prompt = state.get_prompt()
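
    # Save each input image once, keyed by its MD5 hash, under
    # LOGDIR/serve_images/<date>/ so logged conversations can be traced back
    # to their images.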
    all_images = state.get_images(return_pil=True)
    all_image_hash = [hashlib.md5(image.tobytes()).hexdigest() for image in all_images]
    for image, image_hash in zip(all_images, all_image_hash):
        t = datetime.datetime.now()
        filename = os.path.join(
            LOGDIR,
            "serve_images",
            f"{t.year}-{t.month:02d}-{t.day:02d}",
            f"{image_hash}.jpg",
        )
        if not os.path.isfile(filename):
            os.makedirs(os.path.dirname(filename), exist_ok=True)
            image.save(filename)

    # Make requests
    pload = {
        "model": model_name,
        "prompt": prompt,
        "temperature": float(temperature),
        "top_p": float(top_p),
        "max_new_tokens": min(int(max_new_tokens), 1536),
        "stop": state.sep
        if state.sep_style in [SeparatorStyle.SINGLE, SeparatorStyle.MPT]
        else state.sep2,
        "images": f"List of {len(state.get_images())} images: {all_image_hash}",
    }
    logger.info(f"==== request ====\n{pload}")

    # The placeholder string above keeps the log short; now attach the real payload.
    pload["images"] = state.get_images()

    state.messages[-1][-1] = "▌"
    yield (state, state.to_gradio_chatbot()) + (disable_btn,) * 5
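
    # The worker streams JSON chunks separated by NUL bytes; each chunk holds
    # the full text generated so far. The prompt prefix is stripped and the
    # partial reply is re-rendered with a trailing "▌" cursor.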
    try:
        # Stream output
        response = requests.post(
            worker_addr + "/worker_generate_stream",
            headers=headers,
            json=pload,
            stream=True,
            timeout=10,
        )
        for chunk in response.iter_lines(decode_unicode=False, delimiter=b"\0"):
            if chunk:
                data = json.loads(chunk.decode())
                if data["error_code"] == 0:
                    output = data["text"][len(prompt) :].strip()
                    state.messages[-1][-1] = output + "▌"
                    yield (state, state.to_gradio_chatbot()) + (disable_btn,) * 5
                else:
                    output = data["text"] + f" (error_code: {data['error_code']})"
                    state.messages[-1][-1] = output
                    yield (state, state.to_gradio_chatbot()) + (
                        disable_btn,
                        disable_btn,
                        disable_btn,
                        enable_btn,
                        enable_btn,
                    )
                    return
                time.sleep(0.03)
    except requests.exceptions.RequestException as e:
        state.messages[-1][-1] = server_error_msg
        yield (state, state.to_gradio_chatbot()) + (
            disable_btn,
            disable_btn,
            disable_btn,
            enable_btn,
            enable_btn,
        )
        return

    state.messages[-1][-1] = state.messages[-1][-1][:-1]  # drop the "▌" cursor
    yield (state, state.to_gradio_chatbot()) + (enable_btn,) * 5

    finish_tstamp = time.time()
    logger.info(f"{output}")

    with open(get_conv_log_filename(), "a") as fout:
        data = {
            "tstamp": round(finish_tstamp, 4),
            "type": "chat",
            "model": model_name,
            "start": round(start_tstamp, 4),
            "finish": round(finish_tstamp, 4),
            "state": state.dict(),
            "images": all_image_hash,
            "ip": request.client.host,
        }
        fout.write(json.dumps(data) + "\n")

title_markdown = ("""
# 🏭 MaViLa: Manufacturing Vision Language Model
[[Model]](https://huggingface.co/IELTS8/MaViLa-13b)
""")


tos_markdown = """
### Terms of use
By using this service, users are required to agree to the following terms:
The service is a research preview intended for non-commercial use only. It provides only limited safety measures and may generate offensive content. It must not be used for any illegal, harmful, violent, racist, or sexual purposes. The service may collect user dialogue data for future research.
Please click the "Flag" button if you get any inappropriate answer! We will collect those to keep improving our moderator.
For an optimal experience, please use a desktop computer for this demo, as mobile devices may degrade its quality.
"""


learn_more_markdown = """
### License
The service is a research preview intended for non-commercial use only, subject to the model [License](https://github.com/facebookresearch/llama/blob/main/MODEL_CARD.md) of LLaMA, the [Terms of Use](https://openai.com/policies/terms-of-use) of the data generated by OpenAI, and the [Privacy Practices](https://chrome.google.com/webstore/detail/sharegpt-share-your-chatg/daiacboceoaocpibfodeljbdfacokfjb) of ShareGPT. Please contact us if you find any potential violations.
"""

block_css = """
#buttons button {
    min-width: min(120px, 100%);
}
"""
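
# Build the Gradio UI: model picker, image input, sampling sliders, the
# chatbot pane, and the vote/regenerate/clear buttons, then wire up handlers.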
def build_demo(embed_mode):
    global models  # shared with load_demo(), which checks URL params against it
    models = get_model_list()

    textbox = gr.Textbox(
        show_label=False, placeholder="Enter text and press ENTER", container=False
    )
    with gr.Blocks(title="MaViLa", theme=gr.themes.Default(), css=block_css) as demo:
        state = gr.State(default_conversation.copy())

        if not embed_mode:
            gr.Markdown(title_markdown)

        with gr.Row():
            with gr.Column(scale=3):
                with gr.Row(elem_id="model_selector_row"):
                    model_selector = gr.Dropdown(
                        choices=models,
                        value=models[0] if models else "Downloading the models...",
                        interactive=bool(models),
                        show_label=False,
                        container=False,
                    )

                imagebox = gr.Image(type="pil")
                image_process_mode = gr.Radio(
                    ["Crop", "Resize", "Pad", "Default"],
                    value="Default",
                    label="Preprocess for non-square image",
                    visible=False,
                )

                cur_dir = os.path.dirname(os.path.abspath(__file__))
                gr.Examples(examples=[
                    [f"{cur_dir}/examples/blobs.jpg", "What type of defect can be observed in the image?"],
                    [f"{cur_dir}/examples/spaghetti.jpg", "Can you see any defects in the part shown in the image?"],
                ], inputs=[imagebox, textbox])

                with gr.Accordion("Parameters", open=False) as parameter_row:
                    temperature = gr.Slider(
                        minimum=0.0,
                        maximum=1.0,
                        value=0.6,
                        step=0.1,
                        interactive=True,
                        label="Temperature",
                    )
                    top_p = gr.Slider(
                        minimum=0.0,
                        maximum=1.0,
                        value=0.7,
                        step=0.1,
                        interactive=True,
                        label="Top P",
                    )
                    max_output_tokens = gr.Slider(
                        minimum=0,
                        maximum=1024,
                        value=512,
                        step=64,
                        interactive=True,
                        label="Max output tokens",
                    )

            with gr.Column(scale=8):
                chatbot = gr.Chatbot(
                    elem_id="chatbot", label="MaViLa Chatbot", height=550
                )
                with gr.Row():
                    with gr.Column(scale=8):
                        textbox.render()
                    with gr.Column(scale=1, min_width=50):
                        submit_btn = gr.Button(
                            value="Send", variant="primary", interactive=False
                        )
                with gr.Row(elem_id="buttons") as button_row:
                    upvote_btn = gr.Button(value="👍 Upvote", interactive=False)
                    downvote_btn = gr.Button(value="👎 Downvote", interactive=False)
                    flag_btn = gr.Button(value="⚠️ Flag", interactive=False)
                    # stop_btn = gr.Button(value="⏹️ Stop Generation", interactive=False)
                    regenerate_btn = gr.Button(value="🔄 Regenerate", interactive=False)
                    clear_btn = gr.Button(value="🗑️ Clear history", interactive=False)

        if not embed_mode:
            gr.Markdown(tos_markdown)
            gr.Markdown(learn_more_markdown)
        url_params = gr.JSON(visible=False)
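
        # Each user action first updates conversation state (add_text /
        # regenerate), then chains into http_bot via .then() to stream a reply.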
        # Register listeners
        btn_list = [upvote_btn, downvote_btn, flag_btn, regenerate_btn, clear_btn]
        upvote_btn.click(
            upvote_last_response,
            [state, model_selector],
            [textbox, upvote_btn, downvote_btn, flag_btn],
        )
        downvote_btn.click(
            downvote_last_response,
            [state, model_selector],
            [textbox, upvote_btn, downvote_btn, flag_btn],
        )
        flag_btn.click(
            flag_last_response,
            [state, model_selector],
            [textbox, upvote_btn, downvote_btn, flag_btn],
        )
        regenerate_btn.click(
            regenerate,
            [state, image_process_mode],
            [state, chatbot, textbox, imagebox] + btn_list,
        ).then(
            http_bot,
            [state, model_selector, temperature, top_p, max_output_tokens],
            [state, chatbot] + btn_list,
        )
        clear_btn.click(
            clear_history, None, [state, chatbot, textbox, imagebox] + btn_list
        )

        textbox.submit(
            add_text,
            [state, textbox, imagebox, image_process_mode],
            [state, chatbot, textbox, imagebox] + btn_list,
        ).then(
            http_bot,
            [state, model_selector, temperature, top_p, max_output_tokens],
            [state, chatbot] + btn_list,
        )
        submit_btn.click(
            add_text,
            [state, textbox, imagebox, image_process_mode],
            [state, chatbot, textbox, imagebox] + btn_list,
        ).then(
            http_bot,
            [state, model_selector, temperature, top_p, max_output_tokens],
            [state, chatbot] + btn_list,
        )

        if args.model_list_mode == "once":
            demo.load(
                load_demo,
                [url_params],
                [state, model_selector],
                _js=get_window_url_params,
            )
        elif args.model_list_mode == "reload":
            demo.load(
                load_demo_refresh_model_list, None, [state, model_selector, submit_btn]
            )
        else:
            raise ValueError(f"Unknown model list mode: {args.model_list_mode}")

    return demo
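
# The demo runs the full serving stack in one process tree: a controller and a
# model worker are launched as subprocesses before the web UI starts,
# mirroring the usual multi-process LLaVA serving setup.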
def start_controller():
    logger.info("Starting the controller")
    controller_command = [
        "python",
        "-m",
        "llava.serve.controller",
        "--host",
        "0.0.0.0",
        "--port",
        "10000",
    ]
    return subprocess.Popen(controller_command)

def start_worker(model_path: str, bits=16):
    logger.info(f"Starting the model worker for the model {model_path}")
    model_name = model_path.strip("/").split("/")[-1]
    assert bits in [4, 8, 16], "The model can only be loaded in 16-bit, 8-bit, or 4-bit."
    if bits != 16:
        model_name += f"-{bits}bit"
    worker_command = [
        "python",
        "-m",
        "llava.serve.model_worker",
        "--host",
        "0.0.0.0",
        "--controller",
        "http://localhost:10000",
        "--model-path",
        model_path,
        "--model-name",
        model_name,
    ]
    if bits != 16:
        worker_command += [f"--load-{bits}bit"]
    return subprocess.Popen(worker_command)

def get_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("--host", type=str, default="0.0.0.0")
    parser.add_argument("--port", type=int)
    parser.add_argument("--controller-url", type=str, default="http://localhost:10000")
    parser.add_argument("--concurrency-count", type=int, default=8)
    parser.add_argument(
        "--model-list-mode", type=str, default="reload", choices=["once", "reload"]
    )
    parser.add_argument("--share", action="store_true")
    parser.add_argument("--moderate", action="store_true")
    parser.add_argument("--embed", action="store_true")

    args = parser.parse_args()

    return args
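
# queue() throttles concurrent generations; api_open=False keeps the REST API
# closed for this research preview.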
def start_demo(args):
    demo = build_demo(args.embed)
    demo.queue(
        concurrency_count=args.concurrency_count, status_update_rate=10, api_open=False
    ).launch(server_name=args.host, server_port=args.port, share=args.share)

if __name__ == "__main__":
    args = get_args()
    logger.info(f"args: {args}")
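
    # The `bits` env var selects 4-/8-/16-bit loading; the 4-bit default
    # presumably keeps the 13B model within limited GPU memory.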
    model_path = "IELTS8/MaViLa-13b"
    bits = int(os.getenv("bits", 4))

    controller_proc = start_controller()
    worker_proc = start_worker(model_path, bits=bits)

    # Wait for worker and controller to start
    time.sleep(10)

    exit_status = 0
    try:
        start_demo(args)
    except Exception as e:
        print(e)
        exit_status = 1
    finally:
        worker_proc.kill()
        controller_proc.kill()

    sys.exit(exit_status)