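# Gradio Space for LLaVA: loads the model once at startup and exposes a chat UI,
# API documentation, and a JSON endpoint for programmatic access.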
import gradio as gr
import torch
from PIL import Image
import requests
from io import BytesIO
import json
import time
import os
from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig
from transformers import CLIPVisionModel, CLIPImageProcessor
import warnings

warnings.filterwarnings("ignore")

print("Starting LLaVA deployment...")

device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")

tokenizer = None
model = None
image_processor = None
vision_tower = None


def load_model():
    """Load LLaVA model components"""
    global tokenizer, model, image_processor, vision_tower

    try:
        print("Loading tokenizer...")
        model_path = "liuhaotian/llava-v1.5-7b"
        tokenizer = AutoTokenizer.from_pretrained(model_path)

        print("Loading language model...")
        model = AutoModelForCausalLM.from_pretrained(
            model_path,
            torch_dtype=torch.float16 if device == "cuda" else torch.float32,
            low_cpu_mem_usage=True,
            device_map="auto" if device == "cuda" else None
        )

        print("Loading vision components...")
        vision_tower = CLIPVisionModel.from_pretrained("openai/clip-vit-large-patch14-336")
        image_processor = CLIPImageProcessor.from_pretrained("openai/clip-vit-large-patch14-336")

        if device == "cuda":
            vision_tower = vision_tower.to(device)

        print("Model loaded successfully!")
        return True

    except Exception as e:
        print(f"Error loading model: {str(e)}")
        return False


def process_image(image):
    """Process image for the model"""
    if image is None:
        return None

    try:
        if image.mode != 'RGB':
            image = image.convert('RGB')

        image_tensor = image_processor.preprocess(image, return_tensors='pt')['pixel_values']
        if device == "cuda":
            image_tensor = image_tensor.to(device)

        with torch.no_grad():
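            # CLIP ViT-L/14 @ 336px yields features of shape (1, 577, 1024):
            # 576 image patches plus one CLS token.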
            image_features = vision_tower(image_tensor).last_hidden_state

        return image_features

    except Exception as e:
        print(f"Error processing image: {str(e)}")
        return None


def generate_response(message, image=None, system_prompt="", max_tokens=1024, temperature=0.7):
    """Generate a response using LLaVA"""
    global tokenizer, model, image_processor, vision_tower

    if model is None:
        return "Model not loaded. Please wait for initialization."

    try:
        image_features = None
        if image is not None:
            image_features = process_image(image)
            if image_features is None:
                return "Error processing image."

        # Build the prompt in LLaVA's USER/ASSISTANT format; the <image> tag marks
        # where the image belongs. An optional system prompt is prepended.
        if image is not None:
            user_turn = f"USER: <image>\n{message}\nASSISTANT:"
        else:
            user_turn = f"USER: {message}\nASSISTANT:"
        full_prompt = f"{system_prompt}\n\n{user_turn}" if system_prompt else user_turn

        inputs = tokenizer(full_prompt, return_tensors="pt")
        if device == "cuda":
            inputs = {k: v.to(device) for k, v in inputs.items()}

        # Note: in this simplified pipeline the CLIP image features are computed
        # but not injected into the language model, so generation conditions on
        # the text prompt only.
        with torch.no_grad():
            outputs = model.generate(
                **inputs,
                max_new_tokens=max_tokens,
                temperature=temperature,
                do_sample=True,
                pad_token_id=tokenizer.eos_token_id
            )

        response = tokenizer.decode(outputs[0], skip_special_tokens=True)

        # Keep only the text after the final "ASSISTANT:" marker so the echoed
        # prompt is not returned to the user.
        response = response.split("ASSISTANT:")[-1].strip()

        return response

    except Exception as e:
        return f"Error generating response: {str(e)}"


def api_endpoint(request_json):
    """API endpoint for programmatic access"""
    try:
        data = json.loads(request_json)

        message = data.get("message", "")
        system_prompt = data.get("system_prompt", "")
        image_url = data.get("image_url", None)
        max_tokens = int(data.get("max_tokens", 1024))
        temperature = float(data.get("temperature", 0.7))

        image = None
        if image_url:
            try:
                response = requests.get(image_url, timeout=10)
                response.raise_for_status()
                image = Image.open(BytesIO(response.content))
            except Exception as e:
                return json.dumps({"error": f"Failed to load image: {str(e)}"})

        response_text = generate_response(
            message=message,
            image=image,
            system_prompt=system_prompt,
            max_tokens=max_tokens,
            temperature=temperature
        )

        # Token counts are not tracked in this demo, so usage is reported as zeros.
        return json.dumps({
            "id": f"chatcmpl-{int(time.time())}",
            "object": "chat.completion",
            "created": int(time.time()),
            "model": "llava-v1.5-7b",
            "choices": [{
                "message": {
                    "role": "assistant",
                    "content": response_text
                },
                "index": 0,
                "finish_reason": "stop"
            }],
            "usage": {
                "prompt_tokens": 0,
                "completion_tokens": 0,
                "total_tokens": 0
            }
        })

    except Exception as e:
        return json.dumps({"error": str(e)})


print("Initializing model...")
model_loaded = load_model()


with gr.Blocks(title="LLaVA - Large Language and Vision Assistant", theme=gr.themes.Soft()) as demo:
    gr.Markdown("""
    # LLaVA - Large Language and Vision Assistant

    An open-source chatbot trained by fine-tuning LLaMA/Vicuna on GPT-generated multimodal instruction-following data.

    **Features:**
    - Text-based conversation
    - Image understanding and description
    - API endpoint for integration
    """)

    with gr.Tab("Chat Interface"):
        with gr.Row():
            with gr.Column(scale=1):
                image_input = gr.Image(
                    type="pil",
                    label="Upload Image (Optional)",
                    height=300
                )
                system_prompt = gr.Textbox(
                    label="System Prompt (Optional)",
                    placeholder="You are a helpful assistant that can analyze images...",
                    lines=2
                )

            with gr.Column(scale=2):
                chatbot = gr.Chatbot(
                    label="Conversation",
                    height=400
                )

                msg = gr.Textbox(
                    label="Your Message",
                    placeholder="Type your message here... You can ask about the uploaded image!",
                    lines=2
                )

                with gr.Row():
                    submit_btn = gr.Button("Send", variant="primary")
                    clear_btn = gr.Button("Clear", variant="secondary")

                with gr.Accordion("Advanced Settings", open=False):
                    max_tokens = gr.Slider(
                        minimum=1,
                        maximum=2048,
                        value=1024,
                        step=1,
                        label="Max Tokens"
                    )
                    temperature = gr.Slider(
                        minimum=0.1,
                        maximum=2.0,
                        value=0.7,
                        step=0.1,
                        label="Temperature"
                    )

    with gr.Tab("API Documentation"):
        gr.Markdown("""
        ## API Endpoint Usage

        **Endpoint**: `https://your-space-name.hf.space/api/predict`

        **Method**: POST

        ### Request Format:
        ```json
        {
          "data": [
            "{
              \"message\": \"Describe this image in detail\",
              \"system_prompt\": \"You are a helpful assistant\",
              \"image_url\": \"https://example.com/image.jpg\",
              \"max_tokens\": 1024,
              \"temperature\": 0.7
            }"
          ]
        }
        ```

        ### Response Format:
        ```json
        {
          "data": [
            "{
              \"id\": \"chatcmpl-123456789\",
              \"object\": \"chat.completion\",
              \"created\": 1683123456,
              \"model\": \"llava-v1.5-7b\",
              \"choices\": [
                {
                  \"message\": {
                    \"role\": \"assistant\",
                    \"content\": \"This image shows...\"
                  },
                  \"index\": 0,
                  \"finish_reason\": \"stop\"
                }
              ]
            }"
          ]
        }
        ```

        ### Python Client Example:
        ```python
        import requests
        import json

        def query_llava(message, image_url=None, system_prompt=""):
            payload = {
                "data": [json.dumps({
                    "message": message,
                    "image_url": image_url,
                    "system_prompt": system_prompt,
                    "max_tokens": 1024,
                    "temperature": 0.7
                })]
            }

            response = requests.post(
                "https://your-space-name.hf.space/api/predict",
                json=payload
            )

            if response.status_code == 200:
                result = response.json()
                api_response = json.loads(result["data"][0])
                return api_response["choices"][0]["message"]["content"]
            else:
                return f"Error: {response.status_code}"

        # Example usage
        result = query_llava(
            "What do you see in this image?",
            image_url="https://example.com/image.jpg"
        )
        print(result)
        ```
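
        ### Alternative: gradio_client

        The same endpoint can also be called with the official `gradio_client` package.
        This is a minimal sketch: the Space URL is a placeholder, and the `/predict`
        route name assumes the `api_name` configured on the Test API button below.

        ```python
        from gradio_client import Client
        import json

        # Connect to the deployed Space (replace with your actual Space URL).
        client = Client("https://your-space-name.hf.space")

        # The endpoint takes a single JSON string and returns a JSON string.
        raw = client.predict(
            json.dumps({"message": "Hello!", "max_tokens": 256}),
            api_name="/predict"
        )
        print(json.loads(raw)["choices"][0]["message"]["content"])
        ```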
        """)

        gr.Markdown("### Test API")
        api_input = gr.Textbox(
            label="API Request (JSON)",
            placeholder='{"message": "Hello!", "max_tokens": 1024}',
            lines=4
        )
        api_output = gr.Textbox(
            label="API Response",
            lines=8
        )
        api_test_btn = gr.Button("Test API", variant="primary")

    with gr.Tab("About"):
        gr.Markdown("""
        ## About LLaVA

        **LLaVA (Large Language and Vision Assistant)** is an open-source multimodal AI assistant that combines:

        - **Language Understanding**: Based on Vicuna/LLaMA architecture
        - **Vision Capabilities**: Uses CLIP vision encoder
        - **Multimodal Integration**: Connects vision and language seamlessly

        ### Key Features:
        - **Visual Question Answering**: Ask questions about images
        - **Image Description**: Get detailed descriptions of uploaded images
        - **General Conversation**: Chat about any topic
        - **API Integration**: Easy integration with your applications

        ### Model Information:
        - **Base Model**: LLaVA-v1.5-7B
        - **Vision Encoder**: CLIP ViT-L/14@336px
        - **Language Model**: Vicuna-7B
        - **Training Data**: LLaVA-Instruct-150K

        ### Citation:
        ```
        @misc{liu2023llava,
            title={Visual Instruction Tuning},
            author={Haotian Liu and Chunyuan Li and Qingyang Wu and Yong Jae Lee},
            year={2023},
            eprint={2304.08485},
            archivePrefix={arXiv},
            primaryClass={cs.CV}
        }
        ```

        **GitHub**: [https://github.com/haotian-liu/LLaVA](https://github.com/haotian-liu/LLaVA)
        """)

    def respond(message, chat_history, image, system_prompt, max_tokens, temperature):
        if not message.strip():
            return "", chat_history

        chat_history.append([message, None])

        response = generate_response(
            message=message,
            image=image,
            system_prompt=system_prompt if system_prompt.strip() else "",
            max_tokens=int(max_tokens),
            temperature=temperature
        )

        chat_history[-1][1] = response
        return "", chat_history

    def clear_chat():
        # Clear both the conversation and the message box.
        return [], ""

    submit_btn.click(
        respond,
        [msg, chatbot, image_input, system_prompt, max_tokens, temperature],
        [msg, chatbot]
    )

    msg.submit(
        respond,
        [msg, chatbot, image_input, system_prompt, max_tokens, temperature],
        [msg, chatbot]
    )

    clear_btn.click(clear_chat, outputs=[chatbot, msg])

    # Expose the JSON API tester as a named route so external clients can call it.
    api_test_btn.click(api_endpoint, inputs=api_input, outputs=api_output, api_name="predict")

# The JSON API is exposed through the api_test_btn click handler above
# (api_name="predict"), so no separate gr.Interface is needed here.

if __name__ == "__main__":
    demo.launch(
        server_name="0.0.0.0",
        server_port=7860,
        share=False
    )