# viscot-demo-2 / app.py
# dung-vpt-uney
# Deploy Visual-CoT demo with Zero GPU support
# b90b5f6
"""
Visual-CoT: Chain-of-Thought Reasoning Demo on Hugging Face Spaces
Showcasing Visual Chain-of-Thought with Interactive Benchmark Examples
Paper: Visual CoT: Advancing Multi-Modal Language Models with a Comprehensive
Dataset and Benchmark for Chain-of-Thought Reasoning
https://arxiv.org/abs/2403.16999
"""
import os
import torch
import gradio as gr
from PIL import Image, ImageDraw, ImageFont
import re
import json
import spaces
from pathlib import Path
import requests
from io import BytesIO
from llava.constants import (
IMAGE_TOKEN_INDEX,
DEFAULT_IMAGE_TOKEN,
DEFAULT_IM_START_TOKEN,
DEFAULT_IM_END_TOKEN,
)
from llava.conversation import conv_templates
from llava.model.builder import load_pretrained_model
from llava.utils import disable_torch_init
from llava.mm_utils import (
process_images,
tokenizer_image_token,
get_model_name_from_path,
)
# =============================================================================
# Configuration
# =============================================================================
# Hugging Face Hub repository id of the checkpoint served by this demo.
MODEL_PATH = "deepcs233/VisCoT-7b-336"

# Run on the GPU whenever PyTorch reports one at import time; otherwise use the CPU.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"

# Benchmark dataset names offered in the "Benchmark Explorer" tab; each maps
# to a JSON file name read by load_benchmark_examples().
BENCHMARK_DATASETS = (
    "docvqa flickr30k gqa infographicsvqa openimages textcap textvqa vsr cub"
).split()

# Module-level model cache, populated on first use by load_model_once().
tokenizer = model = image_processor = context_len = None
# =============================================================================
# Model Loading (with Zero GPU optimization)
# =============================================================================
def load_model_once():
    """Return the cached ``(tokenizer, model, image_processor, context_len)`` tuple.

    The first call loads the checkpoint at ``MODEL_PATH`` onto ``DEVICE`` through
    ``load_pretrained_model``, without 8-bit or 4-bit quantization, and stores the
    result in the module-level globals. Later calls return the stored objects
    without reloading.
    """
    global tokenizer, model, image_processor, context_len
    if model is None:
        print("🔄 Loading Visual-CoT model...")
        disable_torch_init()
        (tokenizer, model, image_processor, context_len) = load_pretrained_model(
            MODEL_PATH,
            None,  # no separate base model
            get_model_name_from_path(MODEL_PATH),
            load_8bit=False,
            load_4bit=False,
            device=DEVICE,
        )
        print("✓ Model loaded successfully!")
    return tokenizer, model, image_processor, context_len
# =============================================================================
# Utility Functions
# =============================================================================
# "###[x1, y1, x2, y2]" is preferred; a bare "[x1, y1, x2, y2]" is the fallback.
_BBOX_TAGGED_RE = re.compile(r"###\[([\d\.]+),\s*([\d\.]+),\s*([\d\.]+),\s*([\d\.]+)\]")
_BBOX_PLAIN_RE = re.compile(r"\[([\d\.]+),\s*([\d\.]+),\s*([\d\.]+),\s*([\d\.]+)\]")


def parse_bbox(text):
    """Parse a normalized bounding box from model output.

    Looks for ``###[x1, y1, x2, y2]`` first and falls back to a plain
    ``[x1, y1, x2, y2]``. Only the last match of the chosen pattern is used.

    Args:
        text: Generated text (``None`` or empty is allowed).

    Returns:
        list[float] | None: The four coordinates when every one lies in
        ``[0, 1]``. Otherwise ``None``, which covers no match, an
        unparseable number, or an out-of-range value.
    """
    if not text:
        return None
    matches = _BBOX_TAGGED_RE.findall(text) or _BBOX_PLAIN_RE.findall(text)
    if not matches:
        return None
    try:
        bbox = [float(x) for x in matches[-1]]
    except ValueError:
        # ``[\d\.]+`` also matches non-numbers such as "." or "1.2.3".
        return None
    if all(0 <= x <= 1 for x in bbox):
        return bbox
    return None
def draw_bounding_box(image, bbox, color="red", width=5):
    """Return a copy of *image* with a normalized bounding box and its label drawn on it.

    Args:
        image: PIL image to annotate. It is copied and never modified in place.
        bbox: ``[x1, y1, x2, y2]`` in normalized (0-1) coordinates, or ``None``.
        color: Outline color and label background color.
        width: Outline width in pixels.

    Returns:
        The annotated copy, or *image* itself when *bbox* is ``None``.
    """
    if bbox is None:
        return image
    img = image.copy()
    draw = ImageDraw.Draw(img)
    img_width, img_height = img.size

    # Convert normalized to pixel coordinates. Sort each axis so a box whose
    # corners come back swapped still draws; recent Pillow versions raise
    # ValueError when x1 < x0 or y1 < y0.
    x1, x2 = sorted((int(bbox[0] * img_width), int(bbox[2] * img_width)))
    y1, y2 = sorted((int(bbox[1] * img_height), int(bbox[3] * img_height)))
    draw.rectangle([x1, y1, x2, y2], outline=color, width=width)

    label = f"ROI: [{bbox[0]:.3f}, {bbox[1]:.3f}, {bbox[2]:.3f}, {bbox[3]:.3f}]"
    try:
        font = ImageFont.truetype("/usr/share/fonts/truetype/dejavu/DejaVuSans-Bold.ttf", 14)
    except (OSError, ImportError):
        # Font file missing (OSError) or FreeType unavailable (ImportError):
        # fall back to Pillow's built-in font.
        font = ImageFont.load_default()

    # Place the label above the box. If that would fall off the top edge,
    # draw it just inside the box instead.
    label_y = y1 - 22 if y1 >= 22 else y1 + 2
    text_box = draw.textbbox((x1, label_y), label, font=font)
    draw.rectangle(
        [text_box[0] - 2, text_box[1] - 2, text_box[2] + 2, text_box[3] + 2],
        fill=color,
    )
    draw.text((x1, label_y), label, fill="white", font=font)
    return img
def load_benchmark_examples(dataset_name, num_examples=5, benchmark_dir="viscot_benchmark/benchmark"):
    """Load up to *num_examples* examples from a benchmark JSON file.

    The file ``<benchmark_dir>/<dataset_name>.json`` must contain a JSON list of
    items, each with a ``conversations`` list of ``{"value": ...}`` turns.

    Args:
        dataset_name: Benchmark name, used as the JSON file stem.
        num_examples: Maximum number of items to read from the start of the list.
        benchmark_dir: Directory holding the benchmark JSON files.

    Returns:
        list[dict]: One dict per well-formed item, with keys ``image``,
        ``question``, ``gt_bbox``, ``gt_answer`` and ``dataset``. Returns an
        empty list when the file is missing, unreadable, or not a JSON list.
        Malformed items are skipped individually rather than discarding the
        whole dataset.
    """
    benchmark_file = os.path.join(benchmark_dir, f"{dataset_name}.json")
    if not os.path.exists(benchmark_file):
        return []
    try:
        with open(benchmark_file, "r", encoding="utf-8") as f:
            data = json.load(f)
    except (OSError, ValueError) as e:  # ValueError covers JSON/Unicode decode errors
        print(f"Error loading {dataset_name}: {e}")
        return []
    if not isinstance(data, list):
        print(f"Error loading {dataset_name}: expected a JSON list")
        return []

    examples = []
    for item in data[:num_examples]:
        try:
            conversations = item["conversations"]
            # Turn 0 is the user question, optionally prefixed by the image
            # token and followed by the bbox request, which is dropped here.
            question = (
                conversations[0]["value"]
                .replace("<image>\n", "")
                .split("Please provide")[0]
                .strip()
            )
            # Turn 1 is the bbox answer; turn 3 is the final answer.
            gt_bbox = conversations[1]["value"] if len(conversations) > 1 else None
            gt_answer = conversations[3]["value"] if len(conversations) > 3 else None
            image_file = item.get("image", "")
        except (KeyError, IndexError, TypeError, AttributeError) as e:
            print(f"Skipping malformed item in {dataset_name}: {e!r}")
            continue
        examples.append({
            "image": image_file,
            "question": question,
            "gt_bbox": gt_bbox,
            "gt_answer": gt_answer,
            "dataset": dataset_name,
        })
    return examples
# =============================================================================
# Main Inference Function (with @spaces.GPU decorator)
# =============================================================================
def _run_generation(model, tokenizer, prompt, image_tensor, temperature, max_new_tokens):
    """Run one Visual-CoT generation turn and return the decoded continuation.

    Args:
        model: Model returned by ``load_model_once``.
        tokenizer: Tokenizer returned by ``load_model_once``.
        prompt: Full conversation prompt containing the image placeholder token.
        image_tensor: Preprocessed image tensor (or list of tensors) already on ``DEVICE``.
        temperature: Sampling temperature; values <= 0.001 select greedy decoding.
        max_new_tokens: Upper bound on the number of generated tokens.

    Returns:
        str: The generated text with surrounding whitespace stripped.
    """
    input_ids = tokenizer_image_token(
        prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt"
    ).unsqueeze(0).to(DEVICE)
    with torch.inference_mode():
        output_ids = model.generate(
            input_ids,
            images=image_tensor,
            do_sample=temperature > 0.001,
            # Keep the value strictly positive even when decoding greedily.
            temperature=max(temperature, 0.01),
            max_new_tokens=max_new_tokens,
            use_cache=True,
        )
    prompt_len = input_ids.shape[1]
    generated = output_ids[0]
    # Depending on the LLaVA version, ``generate`` may or may not echo the
    # prompt tokens before the continuation (NOTE(review): confirm for the
    # VisCoT fork). Strip the prompt only when it is actually echoed, so no
    # generated tokens are thrown away in the other case.
    if generated.shape[0] >= prompt_len and torch.equal(generated[:prompt_len], input_ids[0]):
        generated = generated[prompt_len:]
    return tokenizer.decode(generated, skip_special_tokens=True).strip()


@spaces.GPU(duration=120)  # Zero GPU allocation for 120 seconds
def generate_viscot_response(image, question, temperature=0.2, max_tokens=512):
    """Run the two-step Visual-CoT pipeline (ROI detection, then answering) on one image.

    Args:
        image: PIL image from the Gradio input (``None`` when nothing was uploaded).
        question: User question text.
        temperature: Sampling temperature shared by both generation steps.
        max_tokens: Maximum tokens for the step-2 answer. Step 1 is capped at 128.

    Returns:
        tuple: ``(bbox_response, final_answer, image_with_bbox, processing_info)``.
        On invalid input or any exception, the first element carries an error
        message and ``image_with_bbox`` is ``None``.
    """
    if image is None:
        return "❌ Please upload an image!", "", None, ""
    # Guard against ``None`` as well as blank text.
    if not question or not question.strip():
        return "❌ Please enter a question!", "", None, ""
    try:
        # Load model (lazy loading)
        tokenizer, model, image_processor, _context_len = load_model_once()

        # Gradio sliders may deliver floats; generate() needs an int token budget.
        temperature = float(temperature)
        max_tokens = int(max_tokens)

        # Uploads may be RGBA / L / P; normalize to RGB before preprocessing.
        if image.mode != "RGB":
            image = image.convert("RGB")

        conv = conv_templates["llava_v1"].copy()

        # =====================================================================
        # STEP 1: Detect Region of Interest (ROI)
        # =====================================================================
        prompt_step1 = (
            f"{DEFAULT_IMAGE_TOKEN}\n{question} "
            f"Please provide the bounding box coordinate of the region this question asks about."
        )
        conv.append_message(conv.roles[0], prompt_step1)
        conv.append_message(conv.roles[1], None)
        prompt1 = conv.get_prompt()

        image_tensor = process_images([image], image_processor, model.config)
        # Cast to the model's own dtype rather than a hard-coded bfloat16, so
        # the input always matches whatever dtype the checkpoint was loaded in.
        if isinstance(image_tensor, list):
            image_tensor = [t.to(DEVICE, dtype=model.dtype) for t in image_tensor]
        else:
            image_tensor = image_tensor.to(DEVICE, dtype=model.dtype)

        bbox_response = _run_generation(model, tokenizer, prompt1, image_tensor, temperature, 128)
        bbox = parse_bbox(bbox_response)

        # =====================================================================
        # STEP 2: Answer Question with ROI Context
        # =====================================================================
        # NOTE(review): this prompt refers to a "local detail image", but only
        # the full-image tensor is passed and no crop of the detected ROI is
        # fed to the model. Confirm against the VisCoT inference code whether
        # the cropped region should be supplied as a second image.
        conv.messages[-1][-1] = bbox_response
        second_question = (
            f"Please answer the question based on the original image and local detail image. {question}"
        )
        conv.append_message(conv.roles[0], second_question)
        conv.append_message(conv.roles[1], None)
        prompt2 = conv.get_prompt()

        final_answer = _run_generation(model, tokenizer, prompt2, image_tensor, temperature, max_tokens)

        # draw_bounding_box returns the image unchanged when bbox is None.
        image_with_bbox = draw_bounding_box(image, bbox)
        processing_info = f"✓ Processed successfully | Bbox: {bbox if bbox else 'Not detected'}"
        return bbox_response, final_answer, image_with_bbox, processing_info
    except Exception as e:
        # UI boundary: report the failure to the user and log it server-side
        # instead of crashing the Gradio worker.
        import traceback
        traceback.print_exc()
        error_msg = f"❌ Error: {str(e)}\n{traceback.format_exc()}"
        return error_msg, "", None, error_msg
# =============================================================================
# Gradio Interface
# =============================================================================
def _build_interactive_tab():
    """Build the "Interactive Demo" tab: inputs, result panes, examples and event wiring."""
    with gr.Tab("🎨 Interactive Demo"):
        gr.Markdown("""
### Try Visual-CoT with Your Own Images!
Upload an image and ask a question. The model will:
1. **Detect** the region of interest (ROI) → Output bounding box
2. **Analyze** the ROI and full image → Generate answer
""")
        with gr.Row():
            with gr.Column(scale=1):
                # Input
                image_input = gr.Image(
                    type="pil",
                    label="📸 Upload Image",
                    height=400,
                )
                question_input = gr.Textbox(
                    label="❓ Your Question",
                    placeholder="Example: What is unusual about this image?",
                    lines=3,
                )
                with gr.Accordion("⚙️ Advanced Settings", open=False):
                    temperature = gr.Slider(
                        minimum=0.0,
                        maximum=1.0,
                        value=0.2,
                        step=0.05,
                        label="🌡️ Temperature",
                        info="0 = Deterministic, 1 = Creative"
                    )
                    max_tokens = gr.Slider(
                        minimum=128,
                        maximum=1024,
                        value=512,
                        step=64,
                        label="📝 Max Output Tokens"
                    )
                submit_btn = gr.Button("🚀 Analyze Image", variant="primary", size="lg")
                clear_btn = gr.Button("🗑️ Clear", size="sm")
            with gr.Column(scale=1):
                # Output
                gr.Markdown("### 📤 Results")
                with gr.Group():
                    gr.Markdown("#### 🎯 Step 1: Region Detection")
                    bbox_output = gr.Textbox(
                        label="Detected Bounding Box",
                        lines=2,
                        show_copy_button=True,
                    )
                with gr.Group():
                    gr.Markdown("#### 💡 Step 2: Answer")
                    answer_output = gr.Textbox(
                        label="Final Answer",
                        lines=6,
                        show_copy_button=True,
                    )
                with gr.Group():
                    gr.Markdown("#### 🖼️ Visualization")
                    image_output = gr.Image(
                        label="Image with Bounding Box",
                        type="pil",
                        height=350,
                    )
                # Hidden status field; still wired as the 4th output below.
                info_output = gr.Textbox(
                    label="Processing Info",
                    lines=1,
                    visible=False,
                )
        # Example images
        gr.Markdown("### 📋 Try These Examples")
        gr.Examples(
            examples=[
                ["examples/extreme_ironing.jpg", "What is unusual about this image?"],
                ["examples/waterview.jpg", "What are the things I should be cautious about when I visit here?"],
            ],
            inputs=[image_input, question_input],
            label="Click to load example",
        )
        # Event handlers
        submit_btn.click(
            fn=generate_viscot_response,
            inputs=[image_input, question_input, temperature, max_tokens],
            outputs=[bbox_output, answer_output, image_output, info_output],
        )
        clear_btn.click(
            fn=lambda: (None, "", "", "", None, ""),
            outputs=[image_input, question_input, bbox_output, answer_output, image_output, info_output],
        )


def _build_benchmark_tab():
    """Build the "Benchmark Explorer" tab (dataset picker and gallery; loading is still a placeholder)."""
    with gr.Tab("📊 Benchmark Explorer"):
        gr.Markdown("""
### Explore Visual-CoT Benchmark Examples
Select a benchmark dataset and browse annotated examples from our evaluation suite.
These examples showcase the model's performance across diverse visual reasoning tasks.
""")
        with gr.Row():
            dataset_dropdown = gr.Dropdown(
                choices=BENCHMARK_DATASETS,
                value="gqa",
                label="🗂️ Select Benchmark Dataset",
                # Derive the count from the list actually offered; the old
                # hard-coded "13" did not match BENCHMARK_DATASETS.
                info=f"Choose from {len(BENCHMARK_DATASETS)} benchmark datasets",
            )
            load_examples_btn = gr.Button("📥 Load Examples", variant="secondary")
        # Gallery is rendered but not populated yet (see placeholder handler below).
        gr.Gallery(
            label="Benchmark Examples",
            columns=3,
            height=400,
            object_fit="contain",
        )
        gr.Markdown("""
**Select a dataset and click "Load Examples" to view benchmark samples.**
Available benchmarks:
- **DocVQA**: Document visual question answering
- **GQA**: Scene graph question answering
- **TextVQA**: Text-based VQA
- **Flickr30k**: Image captioning & grounding
- **InfographicsVQA**: Infographic understanding
- **OpenImages**: Object detection & description
- And more...
""")
        # Placeholder for benchmark loading (would need actual implementation)
        load_examples_btn.click(
            fn=lambda x: gr.Info(f"Loading {x} examples... (Feature coming soon!)"),
            inputs=[dataset_dropdown],
            outputs=None,
        )


def _build_about_tab():
    """Build the static "About" tab with paper information, resources and citation."""
    with gr.Tab("📚 About"):
        # Blank lines separate the Title/Authors/Conference/Abstract entries so
        # they render as separate paragraphs. The blank line before "---"
        # stops Markdown from turning the abstract into a setext heading.
        # NOTE(review): the "Key Results" figures are not sourced in this file;
        # verify them against the paper before publishing.
        gr.Markdown("""
## 📄 Paper Information

**Title:** Visual CoT: Advancing Multi-Modal Language Models with a Comprehensive Dataset and Benchmark for Chain-of-Thought Reasoning

**Authors:** Hao Shao, Shengju Qian, Han Xiao, Guanglu Song, Zhuofan Zong, Letian Wang, Yu Liu, Hongsheng Li

**Conference:** NeurIPS 2024 (Spotlight) 🎉

**Abstract:**
We introduce Visual-CoT, a comprehensive dataset and benchmark for evaluating chain-of-thought reasoning
in multi-modal language models. Our dataset comprises 438K question-answer pairs with intermediate bounding
box annotations highlighting key regions essential for answering questions. We propose a multi-turn processing
pipeline that dynamically focuses on visual inputs and provides interpretable reasoning steps.

---
## 🏗️ Model Architecture
```
┌─────────────────────────────────────┐
│ Visual-CoT Pipeline │
├─────────────────────────────────────┤
│ │
│ 📸 Image Input │
│ ↓ │
│ 🔍 CLIP ViT-L/14 (Vision Encoder) │
│ ↓ │
│ 🔗 MLP Projector (2-layer) │
│ ↓ │
│ 🧠 LLaMA/Vicuna (Language Model) │
│ ↓ │
│ ┌──────────────┐ │
│ │ Step 1: ROI │ → Bounding Box │
│ └──────────────┘ │
│ ↓ │
│ ┌──────────────┐ │
│ │ Step 2: QA │ → Final Answer │
│ └──────────────┘ │
│ │
└─────────────────────────────────────┘
```
---
## 📊 Key Results
- **Detection Accuracy**: 75.3% (IoU > 0.5)
- **Answer Accuracy**: 82.7% (GPT-3.5 evaluated)
- **Benchmarks**: State-of-the-art on 10+ visual reasoning tasks
- **Model Sizes**: 7B and 13B parameters
- **Resolutions**: 224px and 336px
---
## 🔗 Resources
- 📄 **Paper**: [arXiv:2403.16999](https://arxiv.org/abs/2403.16999)
- 💻 **Code**: [GitHub](https://github.com/deepcs233/Visual-CoT)
- 🤗 **Dataset**: [Hugging Face](https://huggingface.co/datasets/deepcs233/Visual-CoT)
- 🌐 **Project Page**: [https://hao-shao.com/projects/viscot.html](https://hao-shao.com/projects/viscot.html)
- 🎯 **Models**:
- [VisCoT-7b-224](https://huggingface.co/deepcs233/VisCoT-7b-224)
- [VisCoT-7b-336](https://huggingface.co/deepcs233/VisCoT-7b-336)
- [VisCoT-13b-224](https://huggingface.co/deepcs233/VisCoT-13b-224)
- [VisCoT-13b-336](https://huggingface.co/deepcs233/VisCoT-13b-336)
---
## 📜 Citation
If you find our work useful, please cite:
```bibtex
@article{shao2024visual,
title={Visual CoT: Unleashing Chain-of-Thought Reasoning in Multi-Modal Language Models},
author={Shao, Hao and Qian, Shengju and Xiao, Han and Song, Guanglu and Zong, Zhuofan and Wang, Letian and Liu, Yu and Li, Hongsheng},
journal={arXiv preprint arXiv:2403.16999},
year={2024}
}
```
---
## ⚖️ License
- **Code**: Apache License 2.0
- **Dataset**: Research use only
- **Models**: Subject to base LLM license (LLaMA)
---
## 🙏 Acknowledgements
This work is built upon:
- [LLaVA](https://github.com/haotian-liu/LLaVA) - Base architecture
- [Shikra](https://github.com/shikras/shikra) - Positional annotations
- [Vicuna](https://github.com/lm-sys/FastChat) - Language model
- [CLIP](https://github.com/openai/CLIP) - Vision encoder
""")


def create_demo():
    """Create the Gradio interface.

    The layout is a header, an introduction, three tabs (interactive demo,
    benchmark explorer, about) and a footer.

    Returns:
        gr.Blocks: The assembled, not-yet-launched demo.
    """
    # Custom CSS for beautiful UI
    custom_css = """
.gradio-container {
font-family: 'Inter', sans-serif;
}
.header {
text-align: center;
padding: 20px;
background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
color: white;
border-radius: 10px;
margin-bottom: 20px;
}
.info-box {
background: #f0f7ff;
border-left: 4px solid #3b82f6;
padding: 15px;
border-radius: 5px;
margin: 10px 0;
}
.example-box {
border: 2px solid #e5e7eb;
border-radius: 8px;
padding: 10px;
margin: 5px 0;
}
.metric-card {
background: white;
border-radius: 8px;
padding: 15px;
box-shadow: 0 1px 3px rgba(0,0,0,0.1);
margin: 10px 0;
}
"""
    with gr.Blocks(
        theme=gr.themes.Soft(
            primary_hue="indigo",
            secondary_hue="purple",
        ),
        css=custom_css,
        title="Visual-CoT Demo"
    ) as demo:
        # Header
        gr.HTML("""
<div class="header">
<h1>🌋 Visual-CoT: Chain-of-Thought Reasoning</h1>
<p style="font-size: 18px; margin: 10px 0;">
Advancing Multi-Modal Language Models with Visual Chain-of-Thought
</p>
<p style="font-size: 14px; opacity: 0.9;">
📄 <a href="https://arxiv.org/abs/2403.16999" style="color: white; text-decoration: underline;">
Paper (NeurIPS 2024 Spotlight)
</a> |
💻 <a href="https://github.com/deepcs233/Visual-CoT" style="color: white; text-decoration: underline;">
GitHub
</a> |
🤗 <a href="https://huggingface.co/datasets/deepcs233/Visual-CoT" style="color: white; text-decoration: underline;">
Dataset
</a>
</p>
</div>
""")
        # Introduction
        gr.Markdown("""
## 🎯 What is Visual-CoT?
**Visual Chain-of-Thought (VisCoT)** enables AI models to:
- 🎯 **Identify important regions** in images using bounding boxes
- 💭 **Reason step-by-step** like humans (Chain-of-Thought)
- 💡 **Answer questions** about visual content with interpretable explanations
### 📊 Dataset & Model
- **438K** Q&A pairs with bounding box annotations
- **13 diverse benchmarks** (DocVQA, GQA, TextVQA, etc.)
- **LLaVA-1.5 based** architecture with CLIP ViT-L/14
""")
        # Components created inside these helpers attach to the enclosing
        # Blocks/Tabs context, just as if they were written inline.
        with gr.Tabs():
            _build_interactive_tab()
            _build_benchmark_tab()
            _build_about_tab()
        # Footer
        gr.Markdown("""
---
<div style="text-align: center; color: #666; padding: 20px;">
<p>🚀 Powered by <a href="https://huggingface.co/docs/hub/spaces-zerogpu">Zero GPU</a> on Hugging Face Spaces</p>
<p>Made with ❤️ by the Visual-CoT Team</p>
</div>
""")
    return demo
# =============================================================================
# Launch
# =============================================================================
if __name__ == "__main__":
    # Zero GPU jobs run through Gradio's request queue; at most 20 requests may
    # wait at once. ``queue`` returns the Blocks instance, so the call chains.
    demo = create_demo().queue(max_size=20)
    demo.launch()