import sys
import os
import argparse
import time
import subprocess

import gradio as gr
import llava.serve.gradio_web_server as gws
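# This launcher installs FlashAttention, starts the LLaVA controller and a model
# worker as subprocesses, and then serves the Gradio demo defined by build_demo()
# below. The model and quantization level are selected through the `model` and
# `bits` environment variables (see the __main__ block at the bottom).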


def build_demo(embed_mode, cur_dir=None, concurrency_count=10):
    textbox = gr.Textbox(show_label=False, placeholder="Enter text and press ENTER", container=False)
    with gr.Blocks(title="LLaVA", theme=gr.themes.Default(), css=gws.block_css) as demo:
        state = gr.State()

        if not embed_mode:
            gr.Markdown(gws.title_markdown)

        with gr.Row():
            with gr.Column(scale=3):
                with gr.Row(elem_id="model_selector_row"):
                    model_selector = gr.Dropdown(
                        choices=gws.models,
                        value=gws.models[0] if len(gws.models) > 0 else "",
                        interactive=True,
                        show_label=False,
                        container=False)

                imagebox = gr.Image(type="pil")
                image_process_mode = gr.Radio(
                    ["Crop", "Resize", "Pad", "Default"],
                    value="Default",
                    label="Preprocess for non-square image", visible=False)

                if cur_dir is None:
                    cur_dir = os.path.dirname(os.path.abspath(__file__))

                user_prompt = "Evaluate and explain if this chart is misleading"
                gr.Examples(examples=[
                    [f"{cur_dir}/examples/bar_custom_1.png", user_prompt],
                    [f"{cur_dir}/examples/fox_news.jpeg", user_prompt],
                    [f"{cur_dir}/examples/bar_wiki.png", user_prompt],
                ], inputs=[imagebox, textbox])

                with gr.Accordion("Parameters", open=False) as parameter_row:
                    temperature = gr.Slider(minimum=0.0, maximum=1.0, value=0, step=0.1, interactive=True, label="Temperature")
                    top_p = gr.Slider(minimum=0.0, maximum=1.0, value=0.7, step=0.1, interactive=True, label="Top P")
                    max_output_tokens = gr.Slider(minimum=0, maximum=1024, value=512, step=64, interactive=True, label="Max output tokens")

            with gr.Column(scale=8):
                chatbot = gr.Chatbot(
                    elem_id="chatbot",
                    label="LLaVA Chatbot",
                    height=650,
                    layout="panel",
                )
                with gr.Row():
                    with gr.Column(scale=8):
                        textbox.render()
                    with gr.Column(scale=1, min_width=50):
                        submit_btn = gr.Button(value="Send", variant="primary")
                with gr.Row(elem_id="buttons") as button_row:
                    upvote_btn = gr.Button(value="👍 Upvote", interactive=False)
                    downvote_btn = gr.Button(value="👎 Downvote", interactive=False)
                    flag_btn = gr.Button(value="⚠️ Flag", interactive=False)
                    regenerate_btn = gr.Button(value="🔄 Regenerate", interactive=False)
                    clear_btn = gr.Button(value="🗑️ Clear", interactive=False)

        if not embed_mode:
            gr.Markdown(gws.tos_markdown)
            gr.Markdown(gws.learn_more_markdown)
        url_params = gr.JSON(visible=False)

        btn_list = [upvote_btn, downvote_btn, flag_btn, regenerate_btn, clear_btn]
        upvote_btn.click(
            gws.upvote_last_response,
            [state, model_selector],
            [textbox, upvote_btn, downvote_btn, flag_btn]
        )
        downvote_btn.click(
            gws.downvote_last_response,
            [state, model_selector],
            [textbox, upvote_btn, downvote_btn, flag_btn]
        )
        flag_btn.click(
            gws.flag_last_response,
            [state, model_selector],
            [textbox, upvote_btn, downvote_btn, flag_btn]
        )

        regenerate_btn.click(
            gws.regenerate,
            [state, image_process_mode],
            [state, chatbot, textbox, imagebox] + btn_list
        ).then(
            gws.http_bot,
            [state, model_selector, temperature, top_p, max_output_tokens],
            [state, chatbot] + btn_list,
            concurrency_limit=concurrency_count
        )

        clear_btn.click(
            gws.clear_history,
            None,
            [state, chatbot, textbox, imagebox] + btn_list,
            queue=False
        )

        textbox.submit(
            gws.add_text,
            [state, textbox, imagebox, image_process_mode],
            [state, chatbot, textbox, imagebox] + btn_list,
            queue=False
        ).then(
            gws.http_bot,
            [state, model_selector, temperature, top_p, max_output_tokens],
            [state, chatbot] + btn_list,
            concurrency_limit=concurrency_count
        )

        submit_btn.click(
            gws.add_text,
            [state, textbox, imagebox, image_process_mode],
            [state, chatbot, textbox, imagebox] + btn_list
        ).then(
            gws.http_bot,
            [state, model_selector, temperature, top_p, max_output_tokens],
            [state, chatbot] + btn_list,
            concurrency_limit=concurrency_count
        )

        if gws.args.model_list_mode == "once":
            demo.load(
                gws.load_demo,
                [url_params],
                [state, model_selector],
                js=gws.get_window_url_params
            )
        elif gws.args.model_list_mode == "reload":
            demo.load(
                gws.load_demo_refresh_model_list,
                None,
                [state, model_selector],
                queue=False
            )
        else:
            raise ValueError(f"Unknown model list mode: {gws.args.model_list_mode}")

    return demo
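

# A minimal sketch of reusing build_demo() outside this launcher (hypothetical
# values, not part of the original script); gws.args and gws.models must be
# populated first, as the __main__ block below does:
#
#     demo = build_demo(embed_mode=True, concurrency_count=5)
#     demo.queue(api_open=False).launch(server_name="0.0.0.0", server_port=7860)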


# Install FlashAttention up front; the model worker below is launched with
# --use-flash-attn and needs it available in this environment.
subprocess.check_call([sys.executable, '-m', 'pip', 'install', 'flash-attn', '--no-build-isolation', '-U'])


def start_controller():
    print("Starting the controller")
    controller_command = [
        sys.executable,
        "-m",
        "llava.serve.controller",
        "--host",
        "0.0.0.0",
        "--port",
        "10000",
    ]
    print(controller_command)
    return subprocess.Popen(controller_command)
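
# The model worker started below registers with this controller at
# http://localhost:10000, which matches the default --controller-url in __main__.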


def start_worker(model_path: str, bits=16):
    print(f"Starting the model worker for the model {model_path}")
    model_name = model_path.strip("/").split("/")[-1]
    assert bits in [4, 8, 16], "The model can only be loaded in 16-bit, 8-bit, or 4-bit precision."
    if bits != 16:
        model_name += f"-{bits}bit"
    worker_command = [
        sys.executable,
        "-m",
        "llava.serve.model_worker",
        "--host",
        "0.0.0.0",
        "--controller",
        "http://localhost:10000",
        "--model-path",
        model_path,
        "--model-name",
        model_name,
        "--use-flash-attn",
    ]
    if bits != 16:
        worker_command += [f"--load-{bits}bit"]
    print(worker_command)
    return subprocess.Popen(worker_command)
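
# For example, start_worker("liuhaotian/llava-v1.6-mistral-7b", bits=4) launches
# (with sys.executable standing in for "python"):
#   python -m llava.serve.model_worker --host 0.0.0.0 \
#       --controller http://localhost:10000 \
#       --model-path liuhaotian/llava-v1.6-mistral-7b \
#       --model-name llava-v1.6-mistral-7b-4bit \
#       --use-flash-attn --load-4bit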


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--host", type=str, default="0.0.0.0")
    parser.add_argument("--port", type=int)
    parser.add_argument("--controller-url", type=str, default="http://localhost:10000")
    parser.add_argument("--concurrency-count", type=int, default=5)
    parser.add_argument("--model-list-mode", type=str, default="reload", choices=["once", "reload"])
    parser.add_argument("--share", action="store_true")
    parser.add_argument("--moderate", action="store_true")
    parser.add_argument("--embed", action="store_true")
    gws.args = parser.parse_args()
    gws.models = []

    gws.title_markdown += """

ONLY WORKS WITH GPU! By default, we load the model with 4-bit quantization to make it fit on smaller hardware. Set the environment variable `bits` to control the quantization.

Set the environment variable `model` to change the model:
[`liuhaotian/llava-v1.6-mistral-7b`](https://huggingface.co/liuhaotian/llava-v1.6-mistral-7b),
[`liuhaotian/llava-v1.6-vicuna-7b`](https://huggingface.co/liuhaotian/llava-v1.6-vicuna-7b),
[`liuhaotian/llava-v1.6-vicuna-13b`](https://huggingface.co/liuhaotian/llava-v1.6-vicuna-13b),
[`liuhaotian/llava-v1.6-34b`](https://huggingface.co/liuhaotian/llava-v1.6-34b).
"""

    print(f"args: {gws.args}")

    model_path = os.getenv("model", "liuhaotian/llava-v1.6-mistral-7b")
    bits = int(os.getenv("bits", 4))
    concurrency_count = int(os.getenv("concurrency_count", 5))

    controller_proc = start_controller()
    worker_proc = start_worker(model_path, bits=bits)

    # Give the controller and the model worker a moment to start up before launching the UI.
    time.sleep(10)

    exit_status = 0
    try:
        demo = build_demo(embed_mode=False, cur_dir='./', concurrency_count=concurrency_count)
        demo.queue(
            status_update_rate=10,
            api_open=False
        ).launch(
            server_name=gws.args.host,
            server_port=gws.args.port,
            share=gws.args.share
        )
    except Exception as e:
        print(e)
        exit_status = 1
    finally:
        worker_proc.kill()
        controller_proc.kill()

    sys.exit(exit_status)
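
# Example invocation (the filename app.py is an assumption; use whatever name this file is saved as):
#   model=liuhaotian/llava-v1.6-vicuna-7b bits=8 python app.py --port 7860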