# Hugging Face Space file view: app.py (9.35 kB), last commit 337fc0b "Update app.py"
# by erwannd (verified; security scan: "No virus"; raw / history / blame views available).
import sys
import os
import argparse
import time
import subprocess
import gradio as gr
import llava.serve.gradio_web_server as gws
def build_demo(embed_mode, cur_dir=None, concurrency_count=10):
    """Build the Gradio Blocks UI for the chart-evaluation LLaVA demo.

    Args:
        embed_mode: When true, omit the title, terms-of-service and
            learn-more markdown sections.
        cur_dir: Directory containing the ``examples/`` folder with the
            sample chart images. Defaults to the directory of this file.
        concurrency_count: ``concurrency_limit`` applied to every
            ``gws.http_bot`` generation event.

    Returns:
        gr.Blocks: the assembled (not yet launched) demo.

    Raises:
        ValueError: if ``gws.args.model_list_mode`` is neither ``"once"``
            nor ``"reload"``.
    """
    # Created outside the Blocks context and placed later with
    # textbox.render(); this lets gr.Examples (declared earlier in the left
    # column) reference it as an input.
    textbox = gr.Textbox(show_label=False, placeholder="Enter text and press ENTER", container=False)
    with gr.Blocks(title="LLaVA", theme=gr.themes.Default(), css=gws.block_css) as demo:
        state = gr.State()

        if not embed_mode:
            gr.Markdown(gws.title_markdown)

        with gr.Row():
            # Left column: model picker, image input, examples, sampling params.
            with gr.Column(scale=3):
                with gr.Row(elem_id="model_selector_row"):
                    model_selector = gr.Dropdown(
                        choices=gws.models,
                        value=gws.models[0] if len(gws.models) > 0 else "",
                        interactive=True,
                        show_label=False,
                        container=False)

                imagebox = gr.Image(type="pil")
                # Hidden control: the preprocessing mode stays at "Default".
                image_process_mode = gr.Radio(
                    ["Crop", "Resize", "Pad", "Default"],
                    value="Default",
                    label="Preprocess for non-square image", visible=False)

                if cur_dir is None:
                    cur_dir = os.path.dirname(os.path.abspath(__file__))
                user_prompt = "Evaluate and explain if this chart is misleading"
                # os.path.join avoids doubled separators such as ".//examples"
                # when cur_dir already ends with a slash.
                examples_dir = os.path.join(cur_dir, "examples")
                gr.Examples(examples=[
                    [os.path.join(examples_dir, "bar_custom_1.png"), user_prompt],
                    [os.path.join(examples_dir, "fox_news.jpeg"), user_prompt],
                    [os.path.join(examples_dir, "bar_wiki.png"), user_prompt],
                ], inputs=[imagebox, textbox])

                with gr.Accordion("Parameters", open=False):
                    temperature = gr.Slider(minimum=0.0, maximum=1.0, value=0, step=0.1, interactive=True, label="Temperature",)
                    top_p = gr.Slider(minimum=0.0, maximum=1.0, value=0.7, step=0.1, interactive=True, label="Top P",)
                    max_output_tokens = gr.Slider(minimum=0, maximum=1024, value=512, step=64, interactive=True, label="Max output tokens",)

            # Right column: chat transcript, input row and feedback buttons.
            with gr.Column(scale=8):
                chatbot = gr.Chatbot(
                    elem_id="chatbot",
                    label="LLaVA Chatbot",
                    height=650,
                    layout="panel",
                )
                with gr.Row():
                    with gr.Column(scale=8):
                        textbox.render()
                    with gr.Column(scale=1, min_width=50):
                        submit_btn = gr.Button(value="Send", variant="primary")
                with gr.Row(elem_id="buttons"):
                    # All start disabled; they are outputs (via btn_list) of the
                    # gws handlers registered below, which presumably re-enable
                    # them once a response exists.
                    upvote_btn = gr.Button(value="👍 Upvote", interactive=False)
                    downvote_btn = gr.Button(value="👎 Downvote", interactive=False)
                    flag_btn = gr.Button(value="⚠️ Flag", interactive=False)
                    regenerate_btn = gr.Button(value="🔄 Regenerate", interactive=False)
                    clear_btn = gr.Button(value="🗑️ Clear", interactive=False)

        if not embed_mode:
            gr.Markdown(gws.tos_markdown)
            gr.Markdown(gws.learn_more_markdown)
        url_params = gr.JSON(visible=False)

        # Register listeners
        btn_list = [upvote_btn, downvote_btn, flag_btn, regenerate_btn, clear_btn]
        upvote_btn.click(
            gws.upvote_last_response,
            [state, model_selector],
            [textbox, upvote_btn, downvote_btn, flag_btn]
        )
        downvote_btn.click(
            gws.downvote_last_response,
            [state, model_selector],
            [textbox, upvote_btn, downvote_btn, flag_btn]
        )
        flag_btn.click(
            gws.flag_last_response,
            [state, model_selector],
            [textbox, upvote_btn, downvote_btn, flag_btn]
        )

        # Each generating action first updates the conversation state, then
        # chains into gws.http_bot, which is capped at concurrency_count.
        regenerate_btn.click(
            gws.regenerate,
            [state, image_process_mode],
            [state, chatbot, textbox, imagebox] + btn_list
        ).then(
            gws.http_bot,
            [state, model_selector, temperature, top_p, max_output_tokens],
            [state, chatbot] + btn_list,
            concurrency_limit=concurrency_count
        )

        clear_btn.click(
            gws.clear_history,
            None,
            [state, chatbot, textbox, imagebox] + btn_list,
            queue=False
        )

        textbox.submit(
            gws.add_text,
            [state, textbox, imagebox, image_process_mode],
            [state, chatbot, textbox, imagebox] + btn_list,
            queue=False
        ).then(
            gws.http_bot,
            [state, model_selector, temperature, top_p, max_output_tokens],
            [state, chatbot] + btn_list,
            concurrency_limit=concurrency_count
        )

        submit_btn.click(
            gws.add_text,
            [state, textbox, imagebox, image_process_mode],
            [state, chatbot, textbox, imagebox] + btn_list
        ).then(
            gws.http_bot,
            [state, model_selector, temperature, top_p, max_output_tokens],
            [state, chatbot] + btn_list,
            concurrency_limit=concurrency_count
        )

        if gws.args.model_list_mode == "once":
            # url_params is filled client-side by the gws.get_window_url_params
            # JS hook before gws.load_demo runs.
            demo.load(
                gws.load_demo,
                [url_params],
                [state, model_selector],
                js=gws.get_window_url_params
            )
        elif gws.args.model_list_mode == "reload":
            demo.load(
                gws.load_demo_refresh_model_list,
                None,
                [state, model_selector],
                queue=False
            )
        else:
            raise ValueError(f"Unknown model list mode: {gws.args.model_list_mode}")

    return demo
# Install (or upgrade, via -U) flash-attn into the running interpreter at
# import time, before any LLaVA process is spawned; start_worker launches the
# model worker with --use-flash-attn. --no-build-isolation makes pip build
# against packages already installed in this environment (presumably so the
# build can see torch — confirm). check_call raises CalledProcessError if pip
# exits non-zero, which aborts startup.
_FLASH_ATTN_INSTALL_CMD = [
    sys.executable, "-m", "pip", "install",
    "flash-attn", "--no-build-isolation", "-U",
]
subprocess.check_call(_FLASH_ATTN_INSTALL_CMD)
def start_controller():
    """Launch the LLaVA controller as a child process.

    The controller listens on host 0.0.0.0, port 10000; the port must match
    the controller URL handed to the model worker.

    Returns:
        subprocess.Popen: handle of the spawned controller process.
    """
    print("Starting the controller")
    listen_host, listen_port = "0.0.0.0", "10000"
    cmd = [sys.executable, "-m", "llava.serve.controller"]
    cmd.extend(["--host", listen_host, "--port", listen_port])
    print(cmd)
    return subprocess.Popen(cmd)
def start_worker(model_path: str, bits=16, controller_url="http://localhost:10000"):
    """Launch a LLaVA model worker as a child process.

    Args:
        model_path: Hugging Face repo id or local path of the model. Its last
            path component, plus a ``-{bits}bit`` suffix when ``bits != 16``,
            becomes the model name the worker is started with.
        bits: Weight precision to load: 16 (no ``--load-*bit`` flag), 8 or 4.
        controller_url: URL of the controller the worker is pointed at.
            Defaults to port 10000 on localhost, where ``start_controller``
            listens.

    Returns:
        subprocess.Popen: handle of the spawned worker process.

    Raises:
        ValueError: if ``bits`` is not 4, 8 or 16. Checked before anything is
            printed or spawned, and not stripped under ``python -O`` the way an
            ``assert`` would be.
    """
    if bits not in (4, 8, 16):
        raise ValueError(
            f"It can be only loaded with 16-bit, 8-bit, and 4-bit; got bits={bits!r}."
        )

    print(f"Starting the model worker for the model {model_path}")
    model_name = model_path.strip("/").split("/")[-1]
    quant_flags = []
    if bits != 16:
        model_name += f"-{bits}bit"
        quant_flags = [f"--load-{bits}bit"]

    worker_command = [
        sys.executable,
        "-m",
        "llava.serve.model_worker",
        "--host",
        "0.0.0.0",
        "--controller",
        controller_url,
        "--model-path",
        model_path,
        "--model-name",
        model_name,
        "--use-flash-attn",
        *quant_flags,
    ]
    print(worker_command)
    return subprocess.Popen(worker_command)
if __name__ == "__main__":
    # Script entry point: parse flags, start the LLaVA controller and worker,
    # serve the Gradio UI, and always kill the child processes on exit.
    import traceback

    parser = argparse.ArgumentParser()
    parser.add_argument("--host", type=str, default="0.0.0.0")
    parser.add_argument("--port", type=int)
    parser.add_argument("--controller-url", type=str, default="http://localhost:10000")
    parser.add_argument("--concurrency-count", type=int, default=5)
    parser.add_argument("--model-list-mode", type=str, default="reload", choices=["once", "reload"])
    parser.add_argument("--share", action="store_true")
    parser.add_argument("--moderate", action="store_true")
    parser.add_argument("--embed", action="store_true")
    gws.args = parser.parse_args()
    gws.models = []

    gws.title_markdown += """
ONLY WORKS WITH GPU! By default, we load the model with 4-bit quantization to make it fit in smaller hardwares. Set the environment variable `bits` to control the quantization.
Set the environment variable `model` to change the model:
[`liuhaotian/llava-v1.6-mistral-7b`](https://huggingface.co/liuhaotian/llava-v1.6-mistral-7b),
[`liuhaotian/llava-v1.6-vicuna-7b`](https://huggingface.co/liuhaotian/llava-v1.6-vicuna-7b),
[`liuhaotian/llava-v1.6-vicuna-13b`](https://huggingface.co/liuhaotian/llava-v1.6-vicuna-13b),
[`liuhaotian/llava-v1.6-34b`](https://huggingface.co/liuhaotian/llava-v1.6-34b).
"""

    print(f"args: {gws.args}")

    # Environment variables (as set on a Space) take precedence; the
    # concurrency limit now falls back to --concurrency-count rather than
    # silently ignoring that flag (both default to 5).
    model_path = os.getenv("model", "liuhaotian/llava-v1.6-mistral-7b")
    bits = int(os.getenv("bits", 4))
    concurrency_count = int(os.getenv("concurrency_count", gws.args.concurrency_count))

    controller_proc = None
    worker_proc = None
    exit_status = 0
    try:
        # Child processes are started inside the try so that a failure in
        # start_worker (e.g. invalid bits) or an interrupted sleep still kills
        # the controller in the finally clause.
        controller_proc = start_controller()
        worker_proc = start_worker(model_path, bits=bits)

        # Fixed head start for controller and worker; there is no readiness check.
        time.sleep(10)

        demo = build_demo(embed_mode=gws.args.embed, cur_dir='./', concurrency_count=concurrency_count)
        demo.queue(
            status_update_rate=10,
            api_open=False
        ).launch(
            server_name=gws.args.host,
            server_port=gws.args.port,
            share=gws.args.share
        )
    except Exception:
        # Keep the full traceback instead of only str(e).
        traceback.print_exc()
        exit_status = 1
    finally:
        for proc in (worker_proc, controller_proc):
            if proc is not None:
                proc.kill()
        sys.exit(exit_status)