import sys
sys.path.append('.')

import torch
import gradio as gr
from transformers import AutoModelForCausalLM, AutoProcessor
import os
import spaces  # ZeroGPU support on Hugging Face Spaces

class SimpleVideoLLaMA3Interface:
    def __init__(self, model_path):
        self.model_path = model_path
        self.model = None
        self.processor = None
        self.image_formats = ("png", "jpg", "jpeg", "bmp", "gif", "webp")
        self.video_formats = ("mp4", "avi", "mov", "mkv", "webm", "m4v", "3gp", "flv")
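        # Extensions are compared with str.lower().endswith(), so they are listed
        # lowercase and without leading dots.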
        
        # Load processor on CPU (doesn't need GPU)
        print(f"Loading processor from {model_path}...")
        self.processor = AutoProcessor.from_pretrained(model_path, trust_remote_code=True)
        print("Processor loaded successfully!")

    def load_model(self):
        """Load model - this will be called inside GPU-decorated functions"""
        if self.model is None:
            print(f"Loading model from {self.model_path}...")
            self.model = AutoModelForCausalLM.from_pretrained(
                self.model_path,
                trust_remote_code=True,
                device_map="auto",
                torch_dtype=torch.bfloat16,
                attn_implementation="flash_attention_2",
            )
            print("Model loaded successfully!")

    @spaces.GPU(duration=120)  # Allocate GPU for up to 120 seconds
    @torch.inference_mode()
    def predict(self, messages, do_sample=True, temperature=0.7, top_p=0.9, max_new_tokens=4096, fps=10, max_frames=256):
        # Load model inside GPU context
        self.load_model()
        
        if not messages:
            return messages
        
        # Convert Gradio messages to the VideoLLaMA3 conversation format, preserving the full conversation history
        conversation = []
        
        # Group messages into proper conversation turns
        i = 0
        while i < len(messages):
            if messages[i]["role"] == "user":
                # Collect all consecutive user messages into one turn
                user_content = []
                
                while i < len(messages) and messages[i]["role"] == "user":
                    msg = messages[i]
                    print(f"DEBUG: Processing user message {i}: {msg}")
                    print(f"DEBUG: Content type: {type(msg['content'])}")
                    print(f"DEBUG: Content value: {msg['content']}")
                    
                    # Handle different types of user content
                    file_path = None
                    if isinstance(msg["content"], str):
                        print(f"DEBUG: Adding text: {msg['content']}")
                        user_content.append({"type": "text", "text": msg["content"]})
                    elif isinstance(msg["content"], tuple) and len(msg["content"]) > 0:
                        # Gradio can deliver file uploads as a (path, ...) tuple
                        file_path = msg["content"][0]
                        print(f"Processing file from tuple: {file_path}")
                    elif isinstance(msg["content"], dict) and "path" in msg["content"]:
                        # ... or as a {"path": ...} dict (the format produced by add_file below)
                        file_path = msg["content"]["path"]
                        print(f"Processing file from dict: {file_path}")

                    if file_path is not None:
                        # Route the upload by extension: video, image, or an error message
                        if not os.path.exists(file_path):
                            print(f"ERROR: File does not exist: {file_path}")
                            user_content.append({"type": "text", "text": f"Error: Could not find file {file_path}"})
                        elif file_path.lower().endswith(self.video_formats):
                            print(f"✅ DETECTED VIDEO: Adding video with fps={fps}, max_frames={max_frames}")
                            user_content.append({"type": "video", "video": {"video_path": file_path, "fps": fps, "max_frames": max_frames}})
                        elif file_path.lower().endswith(self.image_formats):
                            print(f"✅ DETECTED IMAGE: Adding image: {file_path}")
                            user_content.append({"type": "image", "image": {"image_path": file_path}})
                        else:
                            print(f"❌ UNKNOWN FILE TYPE: {file_path}")
                            user_content.append({"type": "text", "text": f"Unsupported file type: {file_path}"})
                    
                    i += 1
                
                # Add the complete user turn to conversation
                if user_content:
                    conversation.append({"role": "user", "content": user_content})
                    print(f"πŸ“ Added user turn with {len(user_content)} items: {[item.get('type', 'unknown') for item in user_content]}")
            
            elif messages[i]["role"] == "assistant":
                # Add assistant response
                conversation.append({"role": "assistant", "content": messages[i]["content"]})
                print(f"πŸ€– Added assistant turn: {messages[i]['content'][:50]}...")
                i += 1

        if not conversation:
            return messages

        try:
            # Debug: Print conversation structure
            print(f"Conversation structure: {len(conversation)} turns")
            for i, turn in enumerate(conversation):
                role = turn["role"]
                if role == "user":
                    content_types = [item.get("type", "unknown") for item in turn["content"] if isinstance(item, dict)]
                    print(f"Turn {i}: {role} - {content_types}")
                else:
                    print(f"Turn {i}: {role} - text response")
            
            inputs = self.processor(
                conversation=conversation,
                add_system_prompt=True,
                add_generation_prompt=True,
                return_tensors="pt"
            )
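            # Move every tensor the processor produced onto the GPU; non-tensor
            # entries are passed through unchanged.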
            inputs = {k: v.cuda() if isinstance(v, torch.Tensor) else v for k, v in inputs.items()}
            if "pixel_values" in inputs:
                inputs["pixel_values"] = inputs["pixel_values"].to(torch.bfloat16)

            output_ids = self.model.generate(
                **inputs,
                do_sample=do_sample,
                temperature=temperature,
                top_p=top_p,
                max_new_tokens=max_new_tokens
            )
            response = self.processor.batch_decode(output_ids, skip_special_tokens=True)[0].strip()
            
            # Extract just the response part (after the last assistant prompt)
            # Find the last occurrence of common assistant indicators
            for indicator in ["assistant", "Assistant", "ASSISTANT"]:
                if indicator in response:
                    response = response.split(indicator)[-1].strip()
                    break
            
            # Clean up common formatting artifacts
            response = response.lstrip(":")
            response = response.lstrip()
            
            messages.append({"role": "assistant", "content": response})
            return messages
        
        except Exception as e:
            error_msg = f"Error: {str(e)}"
            print(f"Error in prediction: {error_msg}")
            messages.append({"role": "assistant", "content": error_msg})
            return messages

    def create_interface(self):
        with gr.Blocks(title="VideoLLaMA3 AI Curator") as interface:
            gr.Markdown("# 🎬 VideoLLaMA3 AI Curator\nUpload images or videos and ask questions!")
            
            with gr.Row():
                with gr.Column(scale=2):
                    chatbot = gr.Chatbot(type="messages", height=600)
                
                with gr.Column(scale=1):
                    with gr.Tab("Input"):
                        video_input = gr.Video(sources=["upload"], label="Upload Video")
                        image_input = gr.Image(sources=["upload"], type="filepath", label="Upload Image")
                        text_input = gr.Textbox(label="Your Message", placeholder="Ask about the image/video or chat...")
                        submit_btn = gr.Button("Send", variant="primary")
                    
                    with gr.Tab("Settings"):
                        do_sample = gr.Checkbox(value=True, label="Do Sample")
                        temperature = gr.Slider(0.0, 1.0, value=0.7, label="Temperature")
                        top_p = gr.Slider(0.0, 1.0, value=0.9, label="Top P")
                        max_tokens = gr.Slider(256, 8192, value=4096, step=64, label="Max Tokens")
                        fps = gr.Slider(0.5, 15.0, value=10.0, label="Video FPS")
                        max_frames = gr.Slider(32, 512, value=256, step=8, label="Max Frames")

            def add_file(history, file):
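                # Gradio hands the upload over as a filesystem path; wrap it in the
                # {"path": ...} dict that predict() routes by file extension.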
                if file:
                    print(f"DEBUG: Gradio file input: {file}")
                    print(f"DEBUG: File type: {type(file)}")
                    history.append({"role": "user", "content": {"path": file}})
                return history, None

            def add_text(history, text):
                if text.strip():
                    history.append({"role": "user", "content": text})
                return history, ""

            def respond(history, do_sample, temperature, top_p, max_tokens, fps, max_frames):
                # Only predict if the last message is from user and we haven't responded to it yet
                if history and history[-1]["role"] == "user":
                    return self.predict(history, do_sample, temperature, top_p, max_tokens, fps, max_frames)
                return history

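            # Event wiring: file uploads are appended to the chat immediately;
            # submitted text is appended and then triggers a model response.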
            video_input.change(add_file, [chatbot, video_input], [chatbot, video_input])
            image_input.change(add_file, [chatbot, image_input], [chatbot, image_input])
            text_input.submit(add_text, [chatbot, text_input], [chatbot, text_input]).then(
                respond, [chatbot, do_sample, temperature, top_p, max_tokens, fps, max_frames], [chatbot]
            )
            submit_btn.click(add_text, [chatbot, text_input], [chatbot, text_input]).then(
                respond, [chatbot, do_sample, temperature, top_p, max_tokens, fps, max_frames], [chatbot]
            )

        return interface

# For Hugging Face Spaces
app = SimpleVideoLLaMA3Interface("DAMO-NLP-SG/VideoLLaMA3-7B")
interface = app.create_interface()

if __name__ == "__main__":
    interface.launch()
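
# Programmatic-use sketch (assumption: a CUDA GPU and the same dependencies are
# available; "./example.mp4" is a placeholder path, and history items follow the
# {"role", "content"} format used by gr.Chatbot(type="messages")):
#
#     history = [{"role": "user", "content": ("./example.mp4",)},
#                {"role": "user", "content": "Describe this clip."}]
#     history = app.predict(history, do_sample=True, temperature=0.7,
#                           top_p=0.9, max_new_tokens=512, fps=1, max_frames=64)
#     print(history[-1]["content"])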