# llava-api / app.py
import gradio as gr
import torch
from PIL import Image
import requests
from io import BytesIO
import json
import time
import os
from transformers import AutoTokenizer, AutoModelForCausalLM
from transformers import CLIPVisionModel, CLIPImageProcessor
import warnings
warnings.filterwarnings("ignore")
print("πŸš€ Starting LLaVA deployment...")
# Check GPU availability
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"πŸ’» Using device: {device}")
# Global variables for model components
tokenizer = None
model = None
image_processor = None
vision_tower = None
def load_model():
"""Load LLaVA model components"""
global tokenizer, model, image_processor, vision_tower
try:
print("πŸ“¦ Loading tokenizer...")
# Use the smaller 7B model for free tier
model_path = "liuhaotian/llava-v1.5-7b"
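        # NOTE: this original checkpoint targets the upstream LLaVA codebase; loading it through
        # AutoModelForCausalLM on a plain Transformers install may fail or fall back to a
        # text-only LLaMA path. A fully integrated alternative (an assumption, not wired up here)
        # is the converted "llava-hf/llava-1.5-7b-hf" checkpoint loaded with AutoProcessor and
        # LlavaForConditionalGeneration, which handles the <image> token projection natively.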
tokenizer = AutoTokenizer.from_pretrained(model_path)
print("🧠 Loading language model...")
model = AutoModelForCausalLM.from_pretrained(
model_path,
torch_dtype=torch.float16 if device == "cuda" else torch.float32,
low_cpu_mem_usage=True,
device_map="auto" if device == "cuda" else None
)
print("πŸ‘οΈ Loading vision components...")
# Load vision tower
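        # openai/clip-vit-large-patch14-336 is the ViT-L/14 @ 336px encoder that LLaVA-1.5 uses as its vision tower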
vision_tower = CLIPVisionModel.from_pretrained("openai/clip-vit-large-patch14-336")
image_processor = CLIPImageProcessor.from_pretrained("openai/clip-vit-large-patch14-336")
if device == "cuda":
vision_tower = vision_tower.to(device)
print("βœ… Model loaded successfully!")
return True
except Exception as e:
print(f"❌ Error loading model: {str(e)}")
return False
def process_image(image):
"""Process image for the model"""
if image is None:
return None
try:
# Convert to RGB if needed
if image.mode != 'RGB':
image = image.convert('RGB')
# Process image
image_tensor = image_processor.preprocess(image, return_tensors='pt')['pixel_values']
if device == "cuda":
image_tensor = image_tensor.to(device)
# Get image features
with torch.no_grad():
image_features = vision_tower(image_tensor).last_hidden_state
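            # For ViT-L/14 @ 336px this is a (1, 577, 1024) tensor: 576 patch embeddings plus one CLS token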
return image_features
except Exception as e:
print(f"Error processing image: {str(e)}")
return None
def generate_response(message, image=None, system_prompt="", max_tokens=1024, temperature=0.7):
"""Generate response using LLaVA"""
global tokenizer, model, image_processor, vision_tower
if model is None:
return "❌ Model not loaded. Please wait for initialization."
try:
# Process image if provided
image_features = None
if image is not None:
image_features = process_image(image)
if image_features is None:
return "❌ Error processing image."
        # Prepare prompt (keep the <image> placeholder even when a system prompt is supplied)
        user_turn = f"<image>\n{message}" if image is not None else message
        if system_prompt:
            full_prompt = f"{system_prompt}\n\nUSER: {user_turn}\nASSISTANT:"
        else:
            full_prompt = f"USER: {user_turn}\nASSISTANT:"
# Tokenize
inputs = tokenizer(full_prompt, return_tensors="pt")
if device == "cuda":
inputs = {k: v.to(device) for k, v in inputs.items()}
# Generate
        with torch.no_grad():
            # NOTE: in this simplified deployment the CLIP image features are computed but never
            # projected into the language model, so generation is effectively text-only; full
            # LLaVA inserts the projected image features at the <image> token position.
            outputs = model.generate(
                **inputs,
                max_new_tokens=max_tokens,
                temperature=temperature,
                do_sample=True,
                pad_token_id=tokenizer.eos_token_id
            )
        # Decode only the newly generated tokens (more robust than slicing the decoded string by prompt length)
        prompt_length = inputs["input_ids"].shape[1]
        response = tokenizer.decode(outputs[0][prompt_length:], skip_special_tokens=True).strip()
return response
except Exception as e:
return f"❌ Error generating response: {str(e)}"
def api_endpoint(request_json):
"""API endpoint for programmatic access"""
try:
data = json.loads(request_json)
message = data.get("message", "")
system_prompt = data.get("system_prompt", "")
image_url = data.get("image_url", None)
max_tokens = int(data.get("max_tokens", 1024))
temperature = float(data.get("temperature", 0.7))
# Process image if URL provided
image = None
        if image_url:
            try:
                image_response = requests.get(image_url, timeout=10)
                image_response.raise_for_status()  # surface non-200 responses instead of silently ignoring them
                image = Image.open(BytesIO(image_response.content))
            except Exception as e:
                return json.dumps({"error": f"Failed to load image: {str(e)}"})
# Generate response
response_text = generate_response(
message=message,
image=image,
system_prompt=system_prompt,
max_tokens=max_tokens,
temperature=temperature
)
# Return API response
return json.dumps({
"id": f"chatcmpl-{int(time.time())}",
"object": "chat.completion",
"created": int(time.time()),
"model": "llava-v1.5-7b",
"choices": [{
"message": {
"role": "assistant",
"content": response_text
},
"index": 0,
"finish_reason": "stop"
}],
"usage": {
"prompt_tokens": 0, # Simplified
"completion_tokens": 0, # Simplified
"total_tokens": 0 # Simplified
}
})
except Exception as e:
return json.dumps({"error": str(e)})
# Initialize model on startup
print("πŸ”„ Initializing model...")
model_loaded = load_model()
# Create Gradio interface
with gr.Blocks(title="LLaVA - Large Language and Vision Assistant", theme=gr.themes.Soft()) as demo:
gr.Markdown("""
# πŸ¦™ LLaVA - Large Language and Vision Assistant
An open-source chatbot trained by fine-tuning LLaMA/Vicuna on GPT-generated multimodal instruction-following data.
**Features:**
- πŸ’¬ Text-based conversation
- πŸ–ΌοΈ Image understanding and description
- πŸ”§ API endpoint for integration
""")
with gr.Tab("πŸ’¬ Chat Interface"):
with gr.Row():
with gr.Column(scale=1):
image_input = gr.Image(
type="pil",
label="πŸ“Έ Upload Image (Optional)",
height=300
)
system_prompt = gr.Textbox(
label="🎯 System Prompt (Optional)",
placeholder="You are a helpful assistant that can analyze images...",
lines=2
)
with gr.Column(scale=2):
chatbot = gr.Chatbot(
label="πŸ’­ Conversation",
height=400
)
msg = gr.Textbox(
label="✍️ Your Message",
placeholder="Type your message here... You can ask about the uploaded image!",
lines=2
)
with gr.Row():
submit_btn = gr.Button("πŸš€ Send", variant="primary")
clear_btn = gr.Button("πŸ—‘οΈ Clear", variant="secondary")
with gr.Accordion("βš™οΈ Advanced Settings", open=False):
max_tokens = gr.Slider(
minimum=1,
maximum=2048,
value=1024,
step=1,
label="πŸ“ Max Tokens"
)
temperature = gr.Slider(
minimum=0.1,
maximum=2.0,
value=0.7,
step=0.1,
label="🌑️ Temperature"
)
with gr.Tab("πŸ”Œ API Documentation"):
gr.Markdown("""
## API Endpoint Usage
**Endpoint**: `https://your-space-name.hf.space/api/predict`
**Method**: POST
### Request Format:
The `data` array must contain a single JSON-encoded **string** (not a nested JSON object) with these fields:
```json
{
  "message": "Describe this image in detail",
  "system_prompt": "You are a helpful assistant",
  "image_url": "https://example.com/image.jpg",
  "max_tokens": 1024,
  "temperature": 0.7
}
```
Serialize the object above to a string and POST it as `{"data": ["<serialized JSON>"]}` (see the Python client below).
### Response Format:
`data[0]` in the response is likewise a JSON-encoded string; decoded, it is an OpenAI-style chat.completion object:
```json
{
  "id": "chatcmpl-123456789",
  "object": "chat.completion",
  "created": 1683123456,
  "model": "llava-v1.5-7b",
  "choices": [
    {
      "message": {
        "role": "assistant",
        "content": "This image shows..."
      },
      "index": 0,
      "finish_reason": "stop"
    }
  ],
  "usage": {
    "prompt_tokens": 0,
    "completion_tokens": 0,
    "total_tokens": 0
  }
}
```
### Python Client Example:
```python
import requests
import json
def query_llava(message, image_url=None, system_prompt=""):
payload = {
"data": [json.dumps({
"message": message,
"image_url": image_url,
"system_prompt": system_prompt,
"max_tokens": 1024,
"temperature": 0.7
})]
}
response = requests.post(
"https://your-space-name.hf.space/api/predict",
json=payload
)
if response.status_code == 200:
result = response.json()
api_response = json.loads(result["data"][0])
return api_response["choices"][0]["message"]["content"]
else:
return f"Error: {response.status_code}"
# Example usage
result = query_llava(
"What do you see in this image?",
image_url="https://example.com/image.jpg"
)
print(result)
```
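### gradio_client Example:
Alternatively, the `gradio_client` package (`pip install gradio_client`) wraps the `data` encoding for you. A minimal sketch, assuming the Space id is `your-username/your-space-name` (replace with the real one) and that the endpoint is exposed under `api_name="predict"`:
```python
from gradio_client import Client
import json

client = Client("your-username/your-space-name")  # hypothetical Space id
raw = client.predict(
    json.dumps({"message": "Hello!", "max_tokens": 256, "temperature": 0.7}),
    api_name="/predict"
)
result = json.loads(raw)  # raw is the JSON string produced by api_endpoint
print(result["choices"][0]["message"]["content"])
```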
""")
# API testing interface
gr.Markdown("### πŸ§ͺ Test API")
api_input = gr.Textbox(
label="πŸ“ API Request (JSON)",
placeholder='{"message": "Hello!", "max_tokens": 1024}',
lines=4
)
api_output = gr.Textbox(
label="πŸ“€ API Response",
lines=8
)
api_test_btn = gr.Button("πŸ§ͺ Test API", variant="primary")
with gr.Tab("ℹ️ About"):
gr.Markdown("""
## About LLaVA
**LLaVA (Large Language and Vision Assistant)** is an open-source multimodal AI assistant that combines:
- 🧠 **Language Understanding**: Based on Vicuna/LLaMA architecture
- πŸ‘οΈ **Vision Capabilities**: Uses CLIP vision encoder
- πŸ”— **Multimodal Integration**: Connects vision and language seamlessly
### Key Features:
- **Visual Question Answering**: Ask questions about images
- **Image Description**: Get detailed descriptions of uploaded images
- **General Conversation**: Chat about any topic
- **API Integration**: Easy integration with your applications
### Model Information:
- **Base Model**: LLaVA-v1.5-7B
- **Vision Encoder**: CLIP ViT-L/14@336px
- **Language Model**: Vicuna-7B
- **Training Data**: LLaVA-Instruct-150K
### Citation:
```
@misc{liu2023llava,
title={Visual Instruction Tuning},
author={Haotian Liu and Chunyuan Li and Qingyang Wu and Yong Jae Lee},
year={2023},
eprint={2304.08485},
archivePrefix={arXiv},
primaryClass={cs.CV}
}
```
**GitHub**: [https://github.com/haotian-liu/LLaVA](https://github.com/haotian-liu/LLaVA)
""")
# Event handlers
def respond(message, chat_history, image, system_prompt, max_tokens, temperature):
if not message.strip():
return "", chat_history
# Add user message to chat
chat_history.append([message, None])
# Generate response
response = generate_response(
message=message,
image=image,
system_prompt=system_prompt if system_prompt.strip() else "",
max_tokens=int(max_tokens),
temperature=temperature
)
# Add assistant response to chat
chat_history[-1][1] = response
return "", chat_history
    def clear_chat():
        # Empty history for the chatbot, empty string for the message textbox
        return [], ""
# Connect event handlers
submit_btn.click(
respond,
[msg, chatbot, image_input, system_prompt, max_tokens, temperature],
[msg, chatbot]
)
msg.submit(
respond,
[msg, chatbot, image_input, system_prompt, max_tokens, temperature],
[msg, chatbot]
)
clear_btn.click(clear_chat, outputs=[chatbot, msg])
    # Expose the API tester as a named endpoint so it is reachable programmatically as "predict"
    api_test_btn.click(api_endpoint, inputs=api_input, outputs=api_output, api_name="predict")
# Launch the app
if __name__ == "__main__":
demo.launch(
server_name="0.0.0.0",
server_port=7860,
share=False
)