import sys
sys.path.append('.')

import os

import torch
import gradio as gr
from transformers import AutoModelForCausalLM, AutoProcessor
import spaces  # ZeroGPU support on Hugging Face Spaces
class SimpleVideoLLaMA3Interface:
    def __init__(self, model_path):
        self.model_path = model_path
        self.model = None
        self.processor = None
        # Leading dots so endswith() matches real extensions rather than
        # any filename that merely ends in these letters
        self.image_formats = (".png", ".jpg", ".jpeg", ".bmp", ".gif", ".webp")
        self.video_formats = (".mp4", ".avi", ".mov", ".mkv", ".webm", ".m4v", ".3gp", ".flv")

        # Load the processor on CPU at startup; it does not need a GPU
        print(f"Loading processor from {model_path}...")
        self.processor = AutoProcessor.from_pretrained(model_path, trust_remote_code=True)
        print("Processor loaded successfully!")
    def load_model(self):
        """Load the model lazily; this is called inside GPU-decorated
        functions so ZeroGPU only allocates a device when it is needed."""
        if self.model is None:
            print(f"Loading model from {self.model_path}...")
            try:
                self.model = AutoModelForCausalLM.from_pretrained(
                    self.model_path,
                    trust_remote_code=True,
                    device_map="auto",
                    torch_dtype=torch.bfloat16,
                    attn_implementation="flash_attention_2",
                )
            except Exception:
                # Fall back to the default attention implementation if
                # flash-attn is not installed in this runtime
                self.model = AutoModelForCausalLM.from_pretrained(
                    self.model_path,
                    trust_remote_code=True,
                    device_map="auto",
                    torch_dtype=torch.bfloat16,
                )
            print("Model loaded successfully!")
    @spaces.GPU(duration=120)  # Allocate a GPU for up to 120 seconds
    @torch.inference_mode()
    def predict(self, messages, do_sample=True, temperature=0.7, top_p=0.9, max_new_tokens=4096, fps=10, max_frames=256):
        # Load the model inside the GPU context
        self.load_model()

        if not messages:
            return messages

        # Convert Gradio messages to the VideoLLaMA3 conversation format,
        # preserving history by grouping consecutive user messages
        # (text plus any uploads) into a single user turn
        conversation = []
        i = 0
        while i < len(messages):
            if messages[i]["role"] == "user":
                user_content = []
                while i < len(messages) and messages[i]["role"] == "user":
                    content = messages[i]["content"]
                    print(f"DEBUG: Processing user message {i}: {content!r}")

                    if isinstance(content, str):
                        user_content.append({"type": "text", "text": content})
                    else:
                        # Gradio delivers uploads either as a (path, ...) tuple
                        # or as a dict with a "path" key
                        file_path = None
                        if isinstance(content, tuple) and len(content) > 0:
                            file_path = content[0]
                        elif isinstance(content, dict) and "path" in content:
                            file_path = content["path"]

                        if file_path is not None:
                            if not os.path.exists(file_path):
                                print(f"ERROR: File does not exist: {file_path}")
                                user_content.append({"type": "text", "text": f"Error: Could not find file {file_path}"})
                            elif file_path.lower().endswith(self.video_formats):
                                print(f"✅ DETECTED VIDEO: Adding video with fps={fps}, max_frames={max_frames}")
                                user_content.append({"type": "video", "video": {"video_path": file_path, "fps": fps, "max_frames": max_frames}})
                            elif file_path.lower().endswith(self.image_formats):
                                print(f"✅ DETECTED IMAGE: Adding image: {file_path}")
                                user_content.append({"type": "image", "image": {"image_path": file_path}})
                            else:
                                print(f"❌ UNKNOWN FILE TYPE: {file_path}")
                                user_content.append({"type": "text", "text": f"Unsupported file type: {file_path}"})
                    i += 1

                # Add the complete user turn to the conversation
                if user_content:
                    conversation.append({"role": "user", "content": user_content})
                    print(f"📝 Added user turn with {len(user_content)} items: {[item.get('type', 'unknown') for item in user_content]}")
            elif messages[i]["role"] == "assistant":
                conversation.append({"role": "assistant", "content": messages[i]["content"]})
                print(f"🤖 Added assistant turn: {messages[i]['content'][:50]}...")
                i += 1
            else:
                # Skip unknown roles so the loop cannot stall
                i += 1

        if not conversation:
            return messages
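        # For reference, a single multimodal user turn in the format the
        # VideoLLaMA3 processor expects looks like this (paths hypothetical):
        #   {"role": "user", "content": [
        #       {"type": "video", "video": {"video_path": "clip.mp4", "fps": 10, "max_frames": 256}},
        #       {"type": "text", "text": "What happens in this video?"},
        #   ]}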
        try:
            # Debug: print the conversation structure
            print(f"Conversation structure: {len(conversation)} turns")
            for idx, turn in enumerate(conversation):
                role = turn["role"]
                if role == "user":
                    content_types = [item.get("type", "unknown") for item in turn["content"] if isinstance(item, dict)]
                    print(f"Turn {idx}: {role} - {content_types}")
                else:
                    print(f"Turn {idx}: {role} - text response")

            inputs = self.processor(
                conversation=conversation,
                add_system_prompt=True,
                add_generation_prompt=True,
                return_tensors="pt"
            )
            # Move tensors to the GPU and match the model's bfloat16 weights
            inputs = {k: v.cuda() if isinstance(v, torch.Tensor) else v for k, v in inputs.items()}
            if "pixel_values" in inputs:
                inputs["pixel_values"] = inputs["pixel_values"].to(torch.bfloat16)

            output_ids = self.model.generate(
                **inputs,
                do_sample=do_sample,
                temperature=temperature,
                top_p=top_p,
                max_new_tokens=max_new_tokens
            )

            # generate() echoes the prompt for decoder-only models, so decode
            # only the newly generated tokens when the prompt length is known;
            # this replaces the fragile splitting on the word "assistant"
            if "input_ids" in inputs:
                output_ids = output_ids[:, inputs["input_ids"].shape[-1]:]
            response = self.processor.batch_decode(output_ids, skip_special_tokens=True)[0].strip()

            # Clean up common formatting artifacts
            response = response.lstrip(":").lstrip()

            messages.append({"role": "assistant", "content": response})
            return messages
        except Exception as e:
            error_msg = f"Error: {str(e)}"
            print(f"Error in prediction: {error_msg}")
            messages.append({"role": "assistant", "content": error_msg})
            return messages
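    # Usage sketch (not part of the UI flow): predict() can also be driven
    # directly with a chat-style history; the file path below is hypothetical.
    #
    #   app = SimpleVideoLLaMA3Interface("DAMO-NLP-SG/VideoLLaMA3-7B")
    #   history = [
    #       {"role": "user", "content": ("/tmp/example.mp4",)},
    #       {"role": "user", "content": "Describe this clip."},
    #   ]
    #   history = app.predict(history, max_new_tokens=512)
    #   print(history[-1]["content"])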
    def create_interface(self):
        with gr.Blocks(title="VideoLLaMA3 AI Curator") as interface:
            gr.Markdown("# 🎬 VideoLLaMA3 AI Curator\nUpload images or videos and ask questions!")

            with gr.Row():
                with gr.Column(scale=2):
                    chatbot = gr.Chatbot(type="messages", height=600)
                with gr.Column(scale=1):
                    with gr.Tab("Input"):
                        video_input = gr.Video(sources=["upload"], label="Upload Video")
                        image_input = gr.Image(sources=["upload"], type="filepath", label="Upload Image")
                        text_input = gr.Textbox(label="Your Message", placeholder="Ask about the image/video or chat...")
                        submit_btn = gr.Button("Send", variant="primary")
                    with gr.Tab("Settings"):
                        do_sample = gr.Checkbox(value=True, label="Do Sample")
                        temperature = gr.Slider(0.0, 1.0, value=0.7, label="Temperature")
                        top_p = gr.Slider(0.0, 1.0, value=0.9, label="Top P")
                        max_tokens = gr.Slider(256, 8192, value=4096, step=64, label="Max Tokens")
                        fps = gr.Slider(0.5, 15.0, value=10.0, label="Video FPS")
                        max_frames = gr.Slider(32, 512, value=256, step=8, label="Max Frames")

            def add_file(history, file):
                # Queue an upload as a user message; returning None clears the
                # component so the same file can be re-uploaded later
                if file:
                    print(f"DEBUG: Gradio file input: {file} ({type(file).__name__})")
                    history.append({"role": "user", "content": {"path": file}})
                return history, None

            def add_text(history, text):
                if text.strip():
                    history.append({"role": "user", "content": text})
                return history, ""

            def respond(history, do_sample, temperature, top_p, max_tokens, fps, max_frames):
                # Only predict if the last message is from the user and has
                # not been answered yet
                if history and history[-1]["role"] == "user":
                    return self.predict(history, do_sample, temperature, top_p, max_tokens, fps, max_frames)
                return history

            video_input.change(add_file, [chatbot, video_input], [chatbot, video_input])
            image_input.change(add_file, [chatbot, image_input], [chatbot, image_input])
            text_input.submit(add_text, [chatbot, text_input], [chatbot, text_input]).then(
                respond, [chatbot, do_sample, temperature, top_p, max_tokens, fps, max_frames], [chatbot]
            )
            submit_btn.click(add_text, [chatbot, text_input], [chatbot, text_input]).then(
                respond, [chatbot, do_sample, temperature, top_p, max_tokens, fps, max_frames], [chatbot]
            )

        return interface
# For Hugging Face Spaces
app = SimpleVideoLLaMA3Interface("DAMO-NLP-SG/VideoLLaMA3-7B")
interface = app.create_interface()
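# Optional hardening (an assumption, not in the original code): Gradio's queue
# limits concurrent GPU calls, which plays nicer with ZeroGPU time slicing.
# interface.queue(max_size=16)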
if __name__ == "__main__":
    interface.launch()