# Based on liuhaotian/LLaVA-1.6

import sys
import os
import argparse
import time
import subprocess

import gradio as gr
import llava.serve.gradio_web_server as gws

# Install build prerequisites and flash-attn at startup; flash-attn needs
# --no-build-isolation so its build step can see the already-installed torch.
subprocess.check_call([sys.executable, '-m', 'pip', 'install', 'wheel', 'setuptools'])
subprocess.check_call([sys.executable, '-m', 'pip', 'install', 'flash-attn', '--no-build-isolation', '-U'])
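
# The installs above run on every launch; a hedged alternative is to guard on
# importability so repeated startups skip the rebuild:
#
#   try:
#       import flash_attn  # noqa: F401
#   except ImportError:
#       subprocess.check_call([sys.executable, '-m', 'pip', 'install',
#                              'flash-attn', '--no-build-isolation', '-U'])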


def start_controller():
    print("Starting the controller")
    controller_command = [
        sys.executable,
        "-m",
        "llava.serve.controller",
        "--host",
        "0.0.0.0",
        "--port",
        "10000",
    ]
    print("Controller Command:", controller_command)
    return subprocess.Popen(controller_command)


def start_worker(model_path: str, bits=16):
    print(f"Starting the model worker for the model {model_path}")
    model_name = model_path.strip("/").split("/")[-1]
    assert bits in [4, 8, 16], "The model can only be loaded in 4-bit, 8-bit, or 16-bit precision."
    if bits != 16:
        model_name += f"-{bits}bit"
    model_name += "-lora"
    worker_command = [
        sys.executable,
        "-m",
        "llava.serve.model_worker",
        "--host",
        "0.0.0.0",
        "--controller",
        "http://localhost:10000",
        "--model-path",
        model_path,
        "--model-name",
        model_name,
        "--model-base",
        "liuhaotian/llava-1.5-7b",
        "--use-flash-attn",
    ]
    if bits != 16:
        # model_worker accepts --load-8bit / --load-4bit for quantized loading;
        # without this flag the bits value would only affect the display name
        worker_command += [f"--load-{bits}bit"]
    print("Worker Command:", worker_command)
    return subprocess.Popen(worker_command)
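

# Optional readiness check: a minimal sketch, assuming the standard LLaVA
# controller HTTP API (POST /refresh_all_workers, and POST /list_models which
# returns {"models": [...]}). It could replace the fixed sleep in __main__.
def wait_for_model(model_name, controller_url="http://localhost:10000", timeout=120):
    import requests  # requests is already a dependency of llava.serve

    deadline = time.time() + timeout
    while time.time() < deadline:
        try:
            requests.post(f"{controller_url}/refresh_all_workers", timeout=5)
            models = requests.post(f"{controller_url}/list_models", timeout=5).json()["models"]
            if model_name in models:
                return True
        except requests.RequestException:
            pass  # controller not reachable yet; keep polling
        time.sleep(2)
    return False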


def handle_text_prompt(text, temperature=0.2, top_p=0.7, max_new_tokens=512):
    """
    Custom API endpoint to handle text prompts.
    Replace the placeholder logic with actual model inference.
    """
    # TODO: Replace the following placeholder with actual model inference code
    print(f"Received prompt: {text}")
    print(f"Parameters - Temperature: {temperature}, Top P: {top_p}, Max New Tokens: {max_new_tokens}")
    
    # Example response (replace with actual model response)
    response = f"Model response to '{text}' with temperature={temperature}, top_p={top_p}, max_new_tokens={max_new_tokens}"
    return response
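

# A minimal sketch of what the placeholder above could be replaced with,
# assuming the standard LLaVA serving API: the controller maps a model name to
# a worker address (POST /get_worker_address), and the worker streams
# b"\0"-delimited JSON chunks ({"text": ..., "error_code": ...}) from
# POST /worker_generate_stream.
def query_worker(prompt, model_name, controller_url="http://localhost:10000",
                 temperature=0.2, top_p=0.7, max_new_tokens=512):
    import json

    import requests

    # Ask the controller which worker currently serves this model
    worker_addr = requests.post(
        f"{controller_url}/get_worker_address", json={"model": model_name}
    ).json()["address"]

    payload = {
        "model": model_name,
        "prompt": prompt,
        "temperature": temperature,
        "top_p": top_p,
        "max_new_tokens": max_new_tokens,
        "stop": None,
    }
    response = requests.post(
        f"{worker_addr}/worker_generate_stream", json=payload, stream=True, timeout=60
    )

    # Each chunk carries the full generation so far; keep the last complete one
    text = ""
    for chunk in response.iter_lines(decode_unicode=False, delimiter=b"\0"):
        if chunk:
            data = json.loads(chunk.decode())
            if data.get("error_code", 0) == 0:
                text = data["text"]
    # In LLaVA the streamed text includes the prompt prefix (assumption); strip it
    return text[len(prompt):].strip()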


def add_text_with_image(text, image, mode):
    """
    Custom API endpoint to add text with an image.
    Replace the placeholder logic with actual processing.
    """
    # TODO: Replace the following placeholder with actual processing code
    print(f"Adding text: {text}")
    print(f"Image path: {image}")
    print(f"Image processing mode: {mode}")
    
    # Example response (replace with actual processing code)
    response = f"Added text '{text}' with image at '{image}' using mode '{mode}'."
    return response
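

# A hedged sketch of how the image could be wired to the worker: LLaVA workers
# accept an "images" field of base64-encoded images alongside a prompt
# containing the <image> token (assumption based on the standard
# llava.serve.model_worker payload format).
def encode_image_base64(image_path):
    import base64

    # Read the raw bytes and encode; the worker decodes base64 back to an image
    with open(image_path, "rb") as f:
        return base64.b64encode(f.read()).decode("utf-8")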


def build_custom_demo(embed_mode=False, cur_dir='./', concurrency_count=5):
    """
    Builds a Gradio Blocks interface with custom API endpoints.

    embed_mode, cur_dir, and concurrency_count mirror the signature of
    llava.serve.gradio_web_server.build_demo but are not used by this
    simplified interface.
    """
    with gr.Blocks() as demo:
        gr.Markdown("# AstroLLaVA")
        gr.Markdown("Welcome to the AstroLLaVA interface. Use the API endpoints to interact with the model.")

        with gr.Row():
            with gr.Column():
                gr.Markdown("## Prompt the Model")
                text_input = gr.Textbox(label="Enter your text prompt", placeholder="Type your prompt here...")
                temperature_slider = gr.Slider(minimum=0.0, maximum=1.0, value=0.2, label="Temperature")
                top_p_slider = gr.Slider(minimum=0.0, maximum=1.0, value=0.7, label="Top P")
                max_tokens_slider = gr.Slider(minimum=1, maximum=1024, value=512, step=1, label="Max New Tokens")
                submit_button = gr.Button("Submit Prompt")
            with gr.Column():
                chatbot_output = gr.Textbox(label="Model Response", interactive=False)

        submit_button.click(
            fn=handle_text_prompt,
            inputs=[text_input, temperature_slider, top_p_slider, max_tokens_slider],
            outputs=chatbot_output,
            api_name="prompt_model"  # Custom API endpoint name
        )

        with gr.Row():
            with gr.Column():
                gr.Markdown("## Add Text with Image")
                add_text_input = gr.Textbox(label="Add Text", placeholder="Enter text to add...")
                add_image_input = gr.Image(label="Upload Image", type="filepath")  # filepath so the handler receives a path, matching its docstring
                image_process_mode = gr.Radio(choices=["Crop", "Resize", "Pad", "Default"], value="Default", label="Image Process Mode")
                add_submit_button = gr.Button("Add Text with Image")
            with gr.Column():
                add_output = gr.Textbox(label="Add Text Response", interactive=False)

        add_submit_button.click(
            fn=add_text_with_image,
            inputs=[add_text_input, add_image_input, image_process_mode],
            outputs=add_output,
            api_name="add_text_with_image"  # Another custom API endpoint
        )

        # Additional API endpoints can be added here following the same structure

    return demo
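

# Example client call against the named endpoints above, using the official
# gradio_client package (pip install gradio_client); the "/prompt_model" route
# comes from the api_name set on the click handler:
#
#   from gradio_client import Client
#   client = Client("http://localhost:7860")
#   result = client.predict(
#       "What does a pulsar look like?",  # text prompt
#       0.2,                              # temperature
#       0.7,                              # top_p
#       512,                              # max_new_tokens
#       api_name="/prompt_model",
#   )
#   print(result)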


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="AstroLLaVA Gradio App")
    parser.add_argument("--host", type=str, default="0.0.0.0", help="Hostname to listen on")
    parser.add_argument("--port", type=int, default=7860, help="Port number")
    parser.add_argument("--controller-url", type=str, default="http://localhost:10000", help="Controller URL")
    parser.add_argument("--concurrency-count", type=int, default=5, help="Number of concurrent requests")
    parser.add_argument("--model-list-mode", type=str, default="reload", choices=["once", "reload"], help="Model list mode")
    parser.add_argument("--share", action="store_true", help="Share the Gradio app publicly")
    parser.add_argument("--moderate", action="store_true", help="Enable moderation")
    parser.add_argument("--embed", action="store_true", help="Enable embed mode")
    args = parser.parse_args()
    gws.args = args
    gws.models = []

    gws.title_markdown += """ AstroLLaVA """

    print(f"AstroLLaVA arguments: {gws.args}")

    model_path = os.getenv("model", "universeTBD/AstroLLaVA_v2")
    bits = int(os.getenv("bits", "4"))
    concurrency_count = int(os.getenv("concurrency_count", "5"))

    controller_proc = start_controller()
    worker_proc = start_worker(model_path, bits=bits)

    # Wait for the worker and controller to start; the fixed sleep is a crude
    # heuristic (see wait_for_model above for a polling alternative)
    print("Waiting for worker and controller to start...")
    time.sleep(30)

    exit_status = 0
    try:
        # Build the custom Gradio demo with additional API endpoints
        demo = build_custom_demo(embed_mode=args.embed, cur_dir='./', concurrency_count=concurrency_count)
        print("Launching Gradio with custom API endpoints...")
        demo.queue(
            status_update_rate=10,
            api_open=False
        ).launch(
            server_name=args.host,
            server_port=args.port,
            share=args.share
        )

    except Exception as e:
        print(f"An error occurred: {e}")
        exit_status = 1
    finally:
        worker_proc.kill()
        controller_proc.kill()

        sys.exit(exit_status)