Spaces:
Build error
Build error
File size: 7,069 Bytes
d1b18a4 ee45c78 d1b18a4 8dc4452 d1b18a4 ee45c78 d1b18a4 b56f08b d1b18a4 b56f08b dc169a4 d1b18a4 ee45c78 d1b18a4 ee45c78 d1b18a4 ee45c78 d1b18a4 ee45c78 d1b18a4 80eb574 d1b18a4 ee45c78 163c9f2 d1b18a4 ee45c78 d1b18a4 ee45c78 d1b18a4 ee45c78 d1b18a4 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 |
# Based on liuhaotian/LLaVA-1.6
import sys
import os
import argparse
import time
import subprocess
import gradio as gr
import llava.serve.gradio_web_server as gws
# Install runtime build dependencies when the module is loaded: first the
# packaging tooling, then flash-attn. flash-attn is installed with
# --no-build-isolation (pip builds it against the packages already installed
# in this environment instead of an isolated build env) and -U (upgrade).
for _pip_args in (
    ("wheel", "setuptools"),
    ("flash-attn", "--no-build-isolation", "-U"),
):
    subprocess.check_call([sys.executable, "-m", "pip", "install", *_pip_args])
del _pip_args
def start_controller():
    """Launch ``llava.serve.controller`` as a child process.

    The controller is bound to all interfaces (0.0.0.0) on port 10000.

    Returns:
        subprocess.Popen: handle to the running controller process.
    """
    print("Starting the controller")
    bind_host, bind_port = "0.0.0.0", "10000"
    cmd = [sys.executable, "-m", "llava.serve.controller"]
    cmd.extend(["--host", bind_host, "--port", bind_port])
    print("Controller Command:", cmd)
    return subprocess.Popen(cmd)
def start_worker(
    model_path: str,
    bits: int = 16,
    model_base: str = "liuhaotian/llava-1.5-7b",
) -> subprocess.Popen:
    """Launch ``llava.serve.model_worker`` for a LoRA checkpoint as a child process.

    Args:
        model_path: Hub id or local path of the checkpoint. Its last path
            component becomes the advertised model name.
        bits: Load precision: 16, 8 or 4. For 8 and 4 the model name gets a
            ``-{bits}bit`` suffix and the worker is passed the matching
            ``--load-{bits}bit`` flag, so the advertised precision is the one
            actually requested from the worker.
        model_base: Base model passed to the worker via ``--model-base``.

    Returns:
        subprocess.Popen: handle to the running worker process.

    Raises:
        ValueError: if ``bits`` is not 4, 8 or 16.
    """
    print(f"Starting the model worker for the model {model_path}")
    # Explicit check instead of ``assert``: asserts are stripped under ``python -O``.
    if bits not in (4, 8, 16):
        raise ValueError("It can be only loaded with 16-bit, 8-bit, and 4-bit.")
    model_name = model_path.strip("/").split("/")[-1]
    if bits != 16:
        model_name += f"-{bits}bit"
    # NOTE(review): the "-lora" suffix appears intended to make LLaVA's loader
    # treat this checkpoint as LoRA weights on top of --model-base; confirm.
    model_name += "-lora"
    worker_command = [
        sys.executable,
        "-m",
        "llava.serve.model_worker",
        "--host",
        "0.0.0.0",
        "--controller",
        "http://localhost:10000",
        "--model-path",
        model_path,
        "--model-name",
        model_name,
        "--model-base",
        model_base,
        "--use-flash-attn",
    ]
    if bits != 16:
        # Without this flag the worker loads the model at full precision even
        # though the name advertises a quantized model.
        worker_command.append(f"--load-{bits}bit")
    print("Worker Command:", worker_command)
    return subprocess.Popen(worker_command)
def handle_text_prompt(text, temperature=0.2, top_p=0.7, max_new_tokens=512):
    """Placeholder handler for the ``prompt_model`` endpoint.

    Logs the prompt and its sampling settings, then echoes them back in a
    canned response. No model is queried yet.

    Args:
        text: The user prompt.
        temperature: Sampling temperature from the UI slider.
        top_p: Nucleus-sampling threshold from the UI slider.
        max_new_tokens: Generation length limit from the UI slider.

    Returns:
        str: The placeholder response string.
    """
    # TODO: replace this echo with real model inference.
    settings = f"temperature={temperature}, top_p={top_p}, max_new_tokens={max_new_tokens}"
    print(f"Received prompt: {text}")
    print(f"Parameters - Temperature: {temperature}, Top P: {top_p}, Max New Tokens: {max_new_tokens}")
    return f"Model response to '{text}' with {settings}"
def add_text_with_image(text, image, mode):
    """Placeholder handler for the ``add_text_with_image`` endpoint.

    Logs its three inputs and returns a canned confirmation string that
    embeds them. No real text/image processing happens yet.

    Args:
        text: Text to attach.
        image: Value delivered by the image input component.
        mode: Selected image processing mode.

    Returns:
        str: The placeholder confirmation string.
    """
    # TODO: replace this echo with real text + image processing.
    log_lines = (
        f"Adding text: {text}",
        f"Image path: {image}",
        f"Image processing mode: {mode}",
    )
    for line in log_lines:
        print(line)
    return f"Added text '{text}' with image at '{image}' using mode '{mode}'."
def build_custom_demo(embed_mode=False, cur_dir='./', concurrency_count=5):
    """Build the AstroLLaVA Gradio Blocks UI with named API endpoints.

    Two events are registered:

    * ``prompt_model``: text prompt plus sampling sliders -> ``handle_text_prompt``.
    * ``add_text_with_image``: text, image and process mode -> ``add_text_with_image``.

    Args:
        embed_mode: Currently unused by this function.
        cur_dir: Currently unused by this function.
        concurrency_count: Currently unused by this function.

    Returns:
        gr.Blocks: The demo, not yet queued or launched.
    """
    with gr.Blocks() as demo:
        gr.Markdown("# AstroLLaVA")
        gr.Markdown("Welcome to the AstroLLaVA interface. Use the API endpoints to interact with the model.")

        with gr.Row():
            with gr.Column():
                gr.Markdown("## Prompt the Model")
                text_input = gr.Textbox(label="Enter your text prompt", placeholder="Type your prompt here...")
                temperature_slider = gr.Slider(minimum=0.0, maximum=1.0, value=0.2, label="Temperature")
                top_p_slider = gr.Slider(minimum=0.0, maximum=1.0, value=0.7, label="Top P")
                max_tokens_slider = gr.Slider(minimum=1, maximum=1024, value=512, step=1, label="Max New Tokens")
                submit_button = gr.Button("Submit Prompt")
            with gr.Column():
                chatbot_output = gr.Textbox(label="Model Response", interactive=False)
        submit_button.click(
            fn=handle_text_prompt,
            inputs=[text_input, temperature_slider, top_p_slider, max_tokens_slider],
            outputs=chatbot_output,
            api_name="prompt_model"  # Custom API endpoint name
        )

        with gr.Row():
            with gr.Column():
                gr.Markdown("## Add Text with Image")
                add_text_input = gr.Textbox(label="Add Text", placeholder="Enter text to add...")
                # type="filepath" makes the handler receive a path string, which is
                # what add_text_with_image logs and reports ("Image path", "image at").
                # Gradio's default (type="numpy") would hand it a full pixel array.
                add_image_input = gr.Image(label="Upload Image", type="filepath")
                image_process_mode = gr.Radio(choices=["Crop", "Resize", "Pad", "Default"], value="Default", label="Image Process Mode")
                add_submit_button = gr.Button("Add Text with Image")
            with gr.Column():
                add_output = gr.Textbox(label="Add Text Response", interactive=False)
        add_submit_button.click(
            fn=add_text_with_image,
            inputs=[add_text_input, add_image_input, image_process_mode],
            outputs=add_output,
            api_name="add_text_with_image"  # Another custom API endpoint
        )

        # Additional API endpoints can be added here following the same structure

    return demo
if __name__ == "__main__":
    import traceback

    parser = argparse.ArgumentParser(description="AstroLLaVA Gradio App")
    parser.add_argument("--host", type=str, default="0.0.0.0", help="Hostname to listen on")
    parser.add_argument("--port", type=int, default=7860, help="Port number")
    parser.add_argument("--controller-url", type=str, default="http://localhost:10000", help="Controller URL")
    parser.add_argument("--concurrency-count", type=int, default=5, help="Number of concurrent requests")
    parser.add_argument("--model-list-mode", type=str, default="reload", choices=["once", "reload"], help="Model list mode")
    parser.add_argument("--share", action="store_true", help="Share the Gradio app publicly")
    parser.add_argument("--moderate", action="store_true", help="Enable moderation")
    parser.add_argument("--embed", action="store_true", help="Enable embed mode")
    args = parser.parse_args()

    # Share the parsed arguments with llava.serve.gradio_web_server through
    # its module-level globals.
    gws.args = args
    gws.models = []
    gws.title_markdown += """ AstroLLaVA """
    print(f"AstroLLaVA arguments: {gws.args}")

    # Deployment settings come from environment variables. For the
    # concurrency count, the --concurrency-count flag is the fallback when the
    # variable is unset (previously the flag was ignored entirely).
    model_path = os.getenv("model", "universeTBD/AstroLLaVA_v2")
    bits = int(os.getenv("bits", 4))
    concurrency_count = int(os.getenv("concurrency_count", args.concurrency_count))

    # Every child process started so far. They are all torn down on exit,
    # including when a later startup step (e.g. start_worker) fails.
    children = []
    exit_status = 0
    try:
        children.append(start_controller())
        children.append(start_worker(model_path, bits=bits))

        # Fixed grace period for the controller and worker to come up.
        print("Waiting for worker and controller to start...")
        time.sleep(30)

        # Build the custom Gradio demo with additional API endpoints
        demo = build_custom_demo(embed_mode=args.embed, cur_dir='./', concurrency_count=concurrency_count)
        print("Launching Gradio with custom API endpoints...")
        demo.queue(
            status_update_rate=10,
            api_open=False
        ).launch(
            server_name=args.host,
            server_port=args.port,
            share=args.share
        )
    except Exception as e:
        print(f"An error occurred: {e}")
        traceback.print_exc()
        exit_status = 1
    finally:
        # Stop in reverse start order (worker, then controller) and reap each
        # process so none is left running or unreaped.
        for proc in reversed(children):
            proc.kill()
            proc.wait()

    # Exit outside ``finally``: calling sys.exit there would replace any
    # in-flight exception (e.g. KeyboardInterrupt) with a silent SystemExit.
    sys.exit(exit_status)
|