"""Launcher for the LLaVA chart-analysis Gradio demo.

Builds the chat UI (``build_demo``) and, when run as a script, installs
flash-attn, spawns the LLaVA controller and a model worker as subprocesses,
then serves the demo with Gradio.
"""

import argparse
import os
import subprocess
import sys
import time

import gradio as gr

import llava.serve.gradio_web_server as gws


def build_demo(embed_mode, cur_dir=None, concurrency_count=10):
    """Assemble the Gradio Blocks UI for the LLaVA chatbot.

    Args:
        embed_mode: When True, suppress the title/TOS markdown (embed-friendly).
        cur_dir: Directory holding the ``examples/`` images; defaults to the
            directory containing this file.
        concurrency_count: ``concurrency_limit`` applied to the generation
            (``gws.http_bot``) event handlers.

    Returns:
        The constructed (not yet launched) ``gr.Blocks`` demo.
    """
    textbox = gr.Textbox(show_label=False, placeholder="Enter text and press ENTER", container=False)
    with gr.Blocks(title="LLaVA", theme=gr.themes.Default(), css=gws.block_css) as demo:
        state = gr.State()

        if not embed_mode:
            gr.Markdown(gws.title_markdown)

        with gr.Row():
            with gr.Column(scale=3):
                with gr.Row(elem_id="model_selector_row"):
                    model_selector = gr.Dropdown(
                        choices=gws.models,
                        value=gws.models[0] if len(gws.models) > 0 else "",
                        interactive=True,
                        show_label=False,
                        container=False)

                imagebox = gr.Image(type="pil")
                image_process_mode = gr.Radio(
                    ["Crop", "Resize", "Pad", "Default"],
                    value="Default",
                    label="Preprocess for non-square image", visible=False)

                if cur_dir is None:
                    cur_dir = os.path.dirname(os.path.abspath(__file__))
                # All examples reuse the same chart-auditing prompt.
                user_prompt = "Evaluate and explain if this chart is misleading"
                gr.Examples(examples=[
                    [f"{cur_dir}/examples/bar_custom_1.png", user_prompt],
                    [f"{cur_dir}/examples/fox_news.jpeg", user_prompt],
                    [f"{cur_dir}/examples/bar_wiki.png", user_prompt],
                ], inputs=[imagebox, textbox])

                with gr.Accordion("Parameters", open=False) as parameter_row:
                    temperature = gr.Slider(minimum=0.0, maximum=1.0, value=0, step=0.1, interactive=True, label="Temperature",)
                    top_p = gr.Slider(minimum=0.0, maximum=1.0, value=0.7, step=0.1, interactive=True, label="Top P",)
                    max_output_tokens = gr.Slider(minimum=0, maximum=1024, value=512, step=64, interactive=True, label="Max output tokens",)

            with gr.Column(scale=8):
                chatbot = gr.Chatbot(
                    elem_id="chatbot",
                    label="LLaVA Chatbot",
                    height=650,
                    layout="panel",
                )
                with gr.Row():
                    with gr.Column(scale=8):
                        textbox.render()
                    with gr.Column(scale=1, min_width=50):
                        submit_btn = gr.Button(value="Send", variant="primary")
                with gr.Row(elem_id="buttons") as button_row:
                    upvote_btn = gr.Button(value="👍 Upvote", interactive=False)
                    downvote_btn = gr.Button(value="👎 Downvote", interactive=False)
                    flag_btn = gr.Button(value="⚠ī¸ Flag", interactive=False)
                    # stop_btn = gr.Button(value="⏚ī¸ Stop Generation", interactive=False)
                    regenerate_btn = gr.Button(value="🔄 Regenerate", interactive=False)
                    clear_btn = gr.Button(value="🗑ī¸ Clear", interactive=False)

        if not embed_mode:
            gr.Markdown(gws.tos_markdown)
            gr.Markdown(gws.learn_more_markdown)
        url_params = gr.JSON(visible=False)

        # Register listeners.  Voting buttons only update themselves; the
        # generate paths (regenerate / submit / send) chain add_text-style
        # handlers into gws.http_bot with the configured concurrency limit.
        btn_list = [upvote_btn, downvote_btn, flag_btn, regenerate_btn, clear_btn]
        upvote_btn.click(
            gws.upvote_last_response,
            [state, model_selector],
            [textbox, upvote_btn, downvote_btn, flag_btn]
        )
        downvote_btn.click(
            gws.downvote_last_response,
            [state, model_selector],
            [textbox, upvote_btn, downvote_btn, flag_btn]
        )
        flag_btn.click(
            gws.flag_last_response,
            [state, model_selector],
            [textbox, upvote_btn, downvote_btn, flag_btn]
        )

        regenerate_btn.click(
            gws.regenerate,
            [state, image_process_mode],
            [state, chatbot, textbox, imagebox] + btn_list
        ).then(
            gws.http_bot,
            [state, model_selector, temperature, top_p, max_output_tokens],
            [state, chatbot] + btn_list,
            concurrency_limit=concurrency_count
        )

        clear_btn.click(
            gws.clear_history,
            None,
            [state, chatbot, textbox, imagebox] + btn_list,
            queue=False
        )

        textbox.submit(
            gws.add_text,
            [state, textbox, imagebox, image_process_mode],
            [state, chatbot, textbox, imagebox] + btn_list,
            queue=False
        ).then(
            gws.http_bot,
            [state, model_selector, temperature, top_p, max_output_tokens],
            [state, chatbot] + btn_list,
            concurrency_limit=concurrency_count
        )

        submit_btn.click(
            gws.add_text,
            [state, textbox, imagebox, image_process_mode],
            [state, chatbot, textbox, imagebox] + btn_list
        ).then(
            gws.http_bot,
            [state, model_selector, temperature, top_p, max_output_tokens],
            [state, chatbot] + btn_list,
            concurrency_limit=concurrency_count
        )

        # Model list is fetched either once at page load (reading URL params)
        # or refreshed on every load, depending on the CLI flag.
        if gws.args.model_list_mode == "once":
            demo.load(
                gws.load_demo,
                [url_params],
                [state, model_selector],
                js=gws.get_window_url_params
            )
        elif gws.args.model_list_mode == "reload":
            demo.load(
                gws.load_demo_refresh_model_list,
                None,
                [state, model_selector],
                queue=False
            )
        else:
            raise ValueError(f"Unknown model list mode: {gws.args.model_list_mode}")

    return demo


# Execute the pip install command with additional options.
# NOTE(review): this runs at import time, not only under __main__ — kept as-is
# because deployment (e.g. HF Spaces) relies on it running before model load.
subprocess.check_call([sys.executable, '-m', 'pip', 'install', 'flash-attn', '--no-build-isolation', '-U'])


def start_controller():
    """Spawn the LLaVA controller subprocess and return its Popen handle."""
    print("Starting the controller")
    controller_command = [
        sys.executable,
        "-m",
        "llava.serve.controller",
        "--host",
        "0.0.0.0",
        "--port",
        "10000",
    ]
    print(controller_command)
    return subprocess.Popen(controller_command)


def start_worker(model_path: str, bits=16):
    """Spawn a model-worker subprocess for *model_path*.

    Args:
        model_path: HF repo id or local path; its last path component becomes
            the worker's model name (suffixed with ``-{bits}bit`` when
            quantized).
        bits: Load precision; must be 4, 8, or 16.

    Returns:
        The worker's ``subprocess.Popen`` handle.

    Raises:
        ValueError: If *bits* is not one of 4, 8, 16.
    """
    print(f"Starting the model worker for the model {model_path}")
    model_name = model_path.strip("/").split("/")[-1]
    # Validate with a real exception rather than assert (asserts vanish under -O).
    if bits not in [4, 8, 16]:
        raise ValueError("It can be only loaded with 16-bit, 8-bit, and 4-bit.")
    if bits != 16:
        model_name += f"-{bits}bit"
    worker_command = [
        sys.executable,
        "-m",
        "llava.serve.model_worker",
        "--host",
        "0.0.0.0",
        "--controller",
        "http://localhost:10000",
        "--model-path",
        model_path,
        "--model-name",
        model_name,
        "--use-flash-attn",
    ]
    if bits != 16:
        worker_command += [f"--load-{bits}bit"]
    print(worker_command)
    return subprocess.Popen(worker_command)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--host", type=str, default="0.0.0.0")
    parser.add_argument("--port", type=int)
    parser.add_argument("--controller-url", type=str, default="http://localhost:10000")
    parser.add_argument("--concurrency-count", type=int, default=5)
    parser.add_argument("--model-list-mode", type=str, default="reload", choices=["once", "reload"])
    parser.add_argument("--share", action="store_true")
    parser.add_argument("--moderate", action="store_true")
    parser.add_argument("--embed", action="store_true")
    gws.args = parser.parse_args()
    gws.models = []

    gws.title_markdown += """
ONLY WORKS WITH GPU!

By default, we load the model with 4-bit quantization to make it fit in smaller hardwares.
Set the environment variable `bits` to control the quantization.

Set the environment variable `model` to change the model:
[`liuhaotian/llava-v1.6-mistral-7b`](https://huggingface.co/liuhaotian/llava-v1.6-mistral-7b),
[`liuhaotian/llava-v1.6-vicuna-7b`](https://huggingface.co/liuhaotian/llava-v1.6-vicuna-7b),
[`liuhaotian/llava-v1.6-vicuna-13b`](https://huggingface.co/liuhaotian/llava-v1.6-vicuna-13b),
[`liuhaotian/llava-v1.6-34b`](https://huggingface.co/liuhaotian/llava-v1.6-34b).
"""
    print(f"args: {gws.args}")

    # Runtime configuration via environment variables (see markdown above).
    model_path = os.getenv("model", "liuhaotian/llava-v1.6-mistral-7b")
    bits = int(os.getenv("bits", 4))
    concurrency_count = int(os.getenv("concurrency_count", 5))

    controller_proc = start_controller()
    worker_proc = start_worker(model_path, bits=bits)

    # Wait for worker and controller to start
    time.sleep(10)

    exit_status = 0
    try:
        demo = build_demo(embed_mode=False, cur_dir='./', concurrency_count=concurrency_count)
        demo.queue(
            status_update_rate=10,
            api_open=False
        ).launch(
            server_name=gws.args.host,
            server_port=gws.args.port,
            share=gws.args.share
        )
    except Exception as e:
        print(e)
        exit_status = 1
    finally:
        # Always tear down the spawned subprocesses, even on failure.
        worker_proc.kill()
        controller_proc.kill()

    sys.exit(exit_status)