# llava/app.py
import time
from io import BytesIO

import gradio as gr
import requests
import torch
import uvicorn
from fastapi import FastAPI, Request
from fastapi.responses import JSONResponse
from PIL import Image
from transformers import AutoProcessor, LlavaForConditionalGeneration

# Load processor and model
processor = AutoProcessor.from_pretrained("llava-hf/llava-1.5-7b-hf")
model = LlavaForConditionalGeneration.from_pretrained(
    "llava-hf/llava-1.5-7b-hf",
    torch_dtype=torch.float16,
    device_map="auto",
)
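
# Note: torch_dtype=torch.float16 with device_map="auto" assumes a GPU is available
# (and requires the accelerate package); on a CPU-only machine one would typically
# load the model in the default float32 instead.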

# Core inference function
def generate_response(user_message, system_prompt=None, image=None, max_tokens=1024, temperature=0.7):
    # Build the LLaVA-1.5 chat prompt ("USER: <image>\n... ASSISTANT:");
    # the <image> token is only included when an image is actually supplied.
    text = f"{system_prompt}\n{user_message}" if system_prompt else user_message
    if image is not None:
        prompt = f"USER: <image>\n{text} ASSISTANT:"
        inputs = processor(text=prompt, images=image, return_tensors="pt").to(model.device, torch.float16)
    else:
        prompt = f"USER: {text} ASSISTANT:"
        inputs = processor(text=prompt, return_tensors="pt").to(model.device)

    with torch.inference_mode():
        output = model.generate(
            **inputs,
            max_new_tokens=int(max_tokens),
            do_sample=True,
            temperature=temperature,
        )

    # Decode only the newly generated tokens (the output sequence starts with the prompt tokens)
    generated_tokens = output[0][inputs["input_ids"].shape[1]:]
    return processor.decode(generated_tokens, skip_special_tokens=True).strip()
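
# Illustrative direct call (not executed here), assuming a local file "example.jpg" exists:
#   img = Image.open("example.jpg").convert("RGB")
#   print(generate_response("What is shown in this image?", image=img))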

# API-style endpoint for programmatic access. The raw JSON route is served by a
# FastAPI app; the Gradio UI is mounted onto the same app at the bottom of the file.
app = FastAPI()

@app.post("/api")
async def api_endpoint(request: Request):
    try:
        data = await request.json()
        user_message = data.get("user_message", "")
        system_prompt = data.get("system_prompt", None)
        image_url = data.get("image_url", None)
        max_tokens = data.get("max_tokens", 1024)
        temperature = data.get("temperature", 0.7)

        image_data = None
        if image_url:
            image_response = requests.get(image_url, timeout=30)
            image_response.raise_for_status()
            image_data = Image.open(BytesIO(image_response.content)).convert("RGB")

        response_text = generate_response(
            user_message=user_message,
            system_prompt=system_prompt,
            image=image_data,
            max_tokens=max_tokens,
            temperature=temperature,
        )

        # OpenAI-style chat completion payload
        return JSONResponse({
            "id": f"chatcmpl-{int(time.time())}",
            "object": "chat.completion",
            "created": int(time.time()),
            "model": "llava-1.5-7b",
            "choices": [{
                "message": {
                    "role": "assistant",
                    "content": response_text
                },
                "index": 0,
                "finish_reason": "stop"
            }]
        })
    except Exception as e:
        return JSONResponse({"error": str(e)}, status_code=500)
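
# Illustrative request body for POST /api (field names match the .get() calls above;
# the image_url value is just a placeholder):
#   {
#     "user_message": "Describe this image.",
#     "system_prompt": "You are a helpful assistant.",
#     "image_url": "https://example.com/cat.jpg",
#     "max_tokens": 256,
#     "temperature": 0.7
#   }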

# Gradio UI
with gr.Blocks() as demo:
    gr.Markdown("# 🔍 LLaVA API Demo")

    with gr.Tab("Test UI"):
        with gr.Row():
            with gr.Column():
                user_message = gr.Textbox(label="User Message", lines=3)
                system_prompt = gr.Textbox(label="System Prompt (Optional)", lines=2)
                image_input = gr.Image(label="Image (Optional)", type="pil")
                max_tokens = gr.Slider(label="Max Tokens", minimum=1, maximum=2048, value=1024, step=1)
                temperature = gr.Slider(label="Temperature", minimum=0.1, maximum=2.0, value=0.7, step=0.1)
                submit_btn = gr.Button("Generate Response")
            with gr.Column():
                output = gr.Textbox(label="Response", lines=10)

        def on_submit(message, system, image, tokens, temp):
            return generate_response(message, system, image, tokens, temp)

        submit_btn.click(
            fn=on_submit,
            inputs=[user_message, system_prompt, image_input, max_tokens, temperature],
            outputs=output,
        )

# Mount the Gradio UI onto the FastAPI app; routes registered above (e.g. /api)
# are matched before the mount at "/", so both are served together.
app = gr.mount_gradio_app(app, demo, path="/")

# Launch
if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=7860)
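
# Example request once the server is running locally (port 7860 as configured above):
#   curl -X POST http://localhost:7860/api \
#     -H "Content-Type: application/json" \
#     -d '{"user_message": "Describe this image.", "image_url": "https://example.com/cat.jpg"}'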