|
import sys |
|
sys.path.append('.') |
|
|
|
import torch |
|
import gradio as gr |
|
from transformers import AutoModelForCausalLM, AutoProcessor |
|
import argparse |
|
import os |
|
import spaces |
|
|
|
class SimpleVideoLLaMA3Interface:
    """Gradio chat front-end for a VideoLLaMA3-style multimodal checkpoint.

    The processor is loaded eagerly in ``__init__``; the model weights are
    loaded lazily by :meth:`load_model` the first time :meth:`predict` runs.
    """

    def __init__(self, model_path):
        """Store configuration and load the processor.

        Args:
            model_path: Hub id or local path handed to ``from_pretrained``.
        """
        self.model_path = model_path
        self.model = None
        self.processor = None
        # File extensions (without the dot) used to classify uploads.
        self.image_formats = ("png", "jpg", "jpeg", "bmp", "gif", "webp")
        self.video_formats = ("mp4", "avi", "mov", "mkv", "webm", "m4v", "3gp", "flv")

        print(f"Loading processor from {model_path}...")
        self.processor = AutoProcessor.from_pretrained(model_path, trust_remote_code=True)
        print("Processor loaded successfully!")

    def load_model(self):
        """Load the model on first use; subsequent calls are no-ops.

        Meant to be called from inside the GPU-decorated :meth:`predict`.
        """
        if self.model is not None:
            return

        print(f"Loading model from {self.model_path}...")
        self.model = AutoModelForCausalLM.from_pretrained(
            self.model_path,
            trust_remote_code=True,
            device_map="auto",
            torch_dtype=torch.bfloat16,
            attn_implementation="flash_attention_2",
        )
        print("Model loaded successfully!")

    def _file_to_content(self, file_path, fps, max_frames):
        """Turn an uploaded file path into one conversation content item.

        Returns a ``video`` or ``image`` item for recognised extensions and a
        ``text`` item describing the problem for missing or unsupported files.
        """
        if not os.path.exists(file_path):
            print(f"ERROR: File does not exist: {file_path}")
            return {"type": "text", "text": f"Error: Could not find file {file_path}"}

        # Compare the real extension rather than a bare suffix, so a name such
        # as "notes_mp4" is not mistaken for a video.
        extension = os.path.splitext(file_path)[1].lower().lstrip(".")
        if extension in self.video_formats:
            print(f"✅ DETECTED VIDEO: Adding video with fps={fps}, max_frames={max_frames}")
            return {
                "type": "video",
                "video": {"video_path": file_path, "fps": fps, "max_frames": max_frames},
            }
        if extension in self.image_formats:
            print(f"✅ DETECTED IMAGE: Adding image: {file_path}")
            return {"type": "image", "image": {"image_path": file_path}}

        print(f"❌ UNKNOWN FILE TYPE: {file_path}")
        return {"type": "text", "text": f"Unsupported file type: {file_path}"}

    def _build_conversation(self, messages, fps, max_frames):
        """Convert chat history into the processor's conversation format.

        Consecutive user messages are merged into a single user turn whose
        content is a list of text/image/video items. Assistant messages are
        passed through unchanged. Messages with any other role are skipped
        (they previously stalled the loop because the index never advanced).
        """
        conversation = []
        user_content = []  # items gathered from the current run of user messages

        def flush_user_turn():
            if not user_content:
                return
            conversation.append({"role": "user", "content": list(user_content)})
            print(
                f"Added user turn with {len(user_content)} items: "
                f"{[item.get('type', 'unknown') for item in user_content]}"
            )
            user_content.clear()

        for index, msg in enumerate(messages):
            role = msg["role"]
            content = msg.get("content")

            if role == "user":
                print(f"DEBUG: Processing user message {index}: {msg}")
                print(f"DEBUG: Content type: {type(content)}")
                print(f"DEBUG: Content value: {content}")

                # File messages are handled in two shapes: a tuple whose first
                # item is the path, or a dict with a "path" key.
                # NOTE(review): which shape arrives presumably depends on the
                # Gradio version / whether history round-tripped -- confirm.
                if isinstance(content, str):
                    print(f"DEBUG: Adding text: {content}")
                    user_content.append({"type": "text", "text": content})
                elif isinstance(content, tuple) and content:
                    print(f"Processing file from tuple: {content[0]}")
                    user_content.append(self._file_to_content(content[0], fps, max_frames))
                elif isinstance(content, dict) and "path" in content:
                    print(f"Processing file from dict: {content['path']}")
                    user_content.append(self._file_to_content(content["path"], fps, max_frames))
                # Any other content shape is ignored.
                continue

            # A non-user message closes the current run of user messages.
            flush_user_turn()
            if role == "assistant":
                conversation.append({"role": "assistant", "content": content})
                # str() keeps the preview safe for non-string content.
                print(f"🤖 Added assistant turn: {str(content)[:50]}...")

        flush_user_turn()
        return conversation

    @staticmethod
    def _strip_prompt_tokens(output_ids, input_ids):
        """Remove an echoed prompt from generated token ids, if present.

        ``generate`` may return either only the new tokens or the prompt
        followed by the new tokens; only the latter is trimmed.
        """
        if not isinstance(input_ids, torch.Tensor):
            return output_ids

        prompt_len = input_ids.shape[-1]
        if output_ids.shape[-1] < prompt_len:
            return output_ids

        prefix = output_ids[..., :prompt_len]
        if torch.equal(prefix, input_ids.to(output_ids.device)):
            return output_ids[..., prompt_len:]
        return output_ids

    @spaces.GPU(duration=120)
    @torch.inference_mode()
    def predict(self, messages, do_sample=True, temperature=0.7, top_p=0.9, max_new_tokens=4096, fps=10, max_frames=256):
        """Generate an assistant reply for the given chat history.

        Args:
            messages: Chat history as a list of ``{"role", "content"}`` dicts.
            do_sample: Whether to sample; greedy decoding otherwise.
            temperature: Sampling temperature; a value of 0 or less selects
                greedy decoding instead of failing inside ``generate``.
            top_p: Nucleus sampling threshold (used only when sampling).
            max_new_tokens: Upper bound on generated tokens.
            fps: Frame rate passed along with video items.
            max_frames: Frame cap passed along with video items.

        Returns:
            The same ``messages`` list with one assistant message appended:
            the model reply, or ``"Error: ..."`` if generation failed. Empty
            input, or history with nothing usable, is returned unchanged.
        """
        # Check for empty input first so the model is not loaded for nothing.
        if not messages:
            return messages

        self.load_model()

        conversation = self._build_conversation(messages, fps, int(max_frames))
        if not conversation:
            return messages

        try:
            print(f"Conversation structure: {len(conversation)} turns")
            for turn_index, turn in enumerate(conversation):
                role = turn["role"]
                if role == "user":
                    content_types = [
                        item.get("type", "unknown")
                        for item in turn["content"]
                        if isinstance(item, dict)
                    ]
                    print(f"Turn {turn_index}: {role} - {content_types}")
                else:
                    print(f"Turn {turn_index}: {role} - text response")

            inputs = self.processor(
                conversation=conversation,
                add_system_prompt=True,
                add_generation_prompt=True,
                return_tensors="pt",
            )

            # Follow the model's device instead of assuming the default CUDA
            # device.
            device = self.model.device
            inputs = {
                key: value.to(device) if isinstance(value, torch.Tensor) else value
                for key, value in inputs.items()
            }
            if "pixel_values" in inputs:
                # Match the dtype the model was loaded with.
                inputs["pixel_values"] = inputs["pixel_values"].to(torch.bfloat16)

            # Sampling needs a strictly positive temperature; otherwise decode
            # greedily and leave the sampling-only arguments out.
            generation_kwargs = {"do_sample": False, "max_new_tokens": int(max_new_tokens)}
            if do_sample and temperature and temperature > 0:
                generation_kwargs.update(do_sample=True, temperature=temperature, top_p=top_p)

            output_ids = self.model.generate(**inputs, **generation_kwargs)

            # Trim an echoed prompt at the token level. Splitting the decoded
            # text on the word "assistant" would truncate any reply that
            # happens to contain that word.
            output_ids = self._strip_prompt_tokens(output_ids, inputs.get("input_ids"))
            response = self.processor.batch_decode(output_ids, skip_special_tokens=True)[0].strip()

            messages.append({"role": "assistant", "content": response})
            return messages

        except Exception as e:
            # UI boundary: show the failure in the chat rather than crashing
            # the event handler.
            error_msg = f"Error: {str(e)}"
            print(f"Error in prediction: {error_msg}")
            messages.append({"role": "assistant", "content": error_msg})
            return messages

    def create_interface(self):
        """Build and return the Gradio ``Blocks`` UI (not yet launched).

        Layout: a chatbot on the left; on the right an "Input" tab (video
        upload, image upload, text box, send button) and a "Settings" tab
        with the generation and video-sampling controls. Uploads are only
        added to the history; a reply is generated when text is submitted.
        """
        with gr.Blocks(title="VideoLLaMA3 AI Curator") as interface:
            gr.Markdown("# 🎬 VideoLLaMA3 AI Curator\nUpload images or videos and ask questions!")

            with gr.Row():
                with gr.Column(scale=2):
                    chatbot = gr.Chatbot(type="messages", height=600)

                with gr.Column(scale=1):
                    with gr.Tab("Input"):
                        video_input = gr.Video(sources=["upload"], label="Upload Video")
                        image_input = gr.Image(sources=["upload"], type="filepath", label="Upload Image")
                        text_input = gr.Textbox(label="Your Message", placeholder="Ask about the image/video or chat...")
                        submit_btn = gr.Button("Send", variant="primary")

                    with gr.Tab("Settings"):
                        do_sample = gr.Checkbox(value=True, label="Do Sample")
                        temperature = gr.Slider(0.0, 1.0, value=0.7, label="Temperature")
                        top_p = gr.Slider(0.0, 1.0, value=0.9, label="Top P")
                        max_tokens = gr.Slider(256, 8192, value=4096, step=64, label="Max Tokens")
                        fps = gr.Slider(0.5, 15.0, value=10.0, label="Video FPS")
                        max_frames = gr.Slider(32, 512, value=256, step=8, label="Max Frames")

            def add_file(history, file):
                """Append an uploaded file to the history and clear the input."""
                history = history if history is not None else []
                if file:
                    print(f"DEBUG: Gradio file input: {file}")
                    print(f"DEBUG: File type: {type(file)}")
                    history.append({"role": "user", "content": {"path": file}})
                return history, None

            def add_text(history, text):
                """Append non-blank text to the history and clear the box."""
                history = history if history is not None else []
                if text and text.strip():
                    history.append({"role": "user", "content": text})
                return history, ""

            def respond(history, do_sample, temperature, top_p, max_tokens, fps, max_frames):
                """Generate a reply only when the last message is from the user."""
                if history and history[-1]["role"] == "user":
                    return self.predict(history, do_sample, temperature, top_p, max_tokens, fps, max_frames)
                return history

            generation_inputs = [chatbot, do_sample, temperature, top_p, max_tokens, fps, max_frames]

            video_input.change(add_file, [chatbot, video_input], [chatbot, video_input])
            image_input.change(add_file, [chatbot, image_input], [chatbot, image_input])
            text_input.submit(add_text, [chatbot, text_input], [chatbot, text_input]).then(
                respond, generation_inputs, [chatbot]
            )
            submit_btn.click(add_text, [chatbot, text_input], [chatbot, text_input]).then(
                respond, generation_inputs, [chatbot]
            )

        return interface
|
|
|
|
|
# Checkpoint served by this app.
MODEL_PATH = "DAMO-NLP-SG/VideoLLaMA3-7B"

# Built at import time (outside the __main__ guard), so importing this module
# also constructs the interface object and the Gradio UI.
app = SimpleVideoLLaMA3Interface(MODEL_PATH)
interface = app.create_interface()


if __name__ == "__main__":
    # Start the Gradio server only when run as a script.
    interface.launch()