File size: 3,965 Bytes
7ec133b
 
 
 
9db455c
 
7ec133b
 
c52e238
fb293c4
c52e238
7ec133b
f8dcf83
a4115fd
 
 
f8dcf83
7ec133b
9db455c
b499d7f
7294f1e
 
b499d7f
7294f1e
 
 
 
 
b499d7f
 
 
7294f1e
 
 
 
 
 
b499d7f
 
 
 
 
 
7294f1e
 
 
9db455c
b499d7f
 
 
 
 
 
 
 
 
 
 
 
 
edff486
feb8185
 
 
 
9db455c
b499d7f
9db455c
edff486
634326a
9db455c
 
 
 
 
 
 
 
 
 
b499d7f
9db455c
7294f1e
9db455c
 
b499d7f
 
9db455c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7ec133b
edff486
 
 
7ec133b
 
 
 
b499d7f
 
 
 
7ec133b
b499d7f
 
 
7ec133b
 
b499d7f
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
import gradio as gr
from PIL import Image
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
import cv2
import numpy as np


# Run on the GPU when one is available. Half precision is only requested on
# the GPU: float16 support in PyTorch CPU kernels is limited (some ops are not
# implemented, others are very slow), so CPU-only hosts load float32 weights.
device = "cuda" if torch.cuda.is_available() else "cpu"
model_dtype = torch.float16 if device == "cuda" else torch.float32

MODEL_ID = "ManishThota/SparrowVQE"

# Load the model and tokenizer once, at import time. trust_remote_code lets the
# Hub repository supply its own model class (presumably where the
# model.image_preprocess method used by predict_answer comes from — confirm in
# the repository's remote code).
model = AutoModelForCausalLM.from_pretrained(MODEL_ID,
                                             torch_dtype=model_dtype,
                                             device_map="auto",
                                             trust_remote_code=True)
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True)


def video_to_frames(video, fps=1):
    """Converts a video file into frames and stores them as PNG images in a list."""
    frames_png = []
    cap = cv2.VideoCapture(video)
    
    if not cap.isOpened():
        print("Error opening video file")
        return frames_png
    
    frame_count = 0
    frame_interval = int(cap.get(cv2.CAP_PROP_FPS)) // fps  # Calculate frame interval
    
    while cap.isOpened():
        ret, frame = cap.read()
        if not ret:
            print("Can't receive frame (stream end?). Exiting ...")
            break
        
        if frame_count % frame_interval == 0:
            is_success, buffer = cv2.imencode(".png", frame)
            if is_success:
                frames_png.append(np.array(buffer).tobytes())
        
        frame_count += 1
    
    cap.release()
    return frames_png

def extract_frames(frame):
    """Decode one encoded frame (e.g. an item from ``video_to_frames``) to an RGB array.

    Args:
        frame: Encoded image bytes, such as the PNG bytes produced by
            ``video_to_frames``.

    Returns:
        A ``numpy.ndarray`` of dtype uint8 and shape (height, width, 3), with
        channels in RGB order.

    Raises:
        ValueError: if the bytes cannot be decoded as an image.
    """
    encoded = np.frombuffer(frame, dtype=np.uint8)

    # With IMREAD_COLOR, OpenCV always decodes to 3-channel 8-bit BGR.
    image_bgr = cv2.imdecode(encoded, flags=cv2.IMREAD_COLOR)
    if image_bgr is None:
        raise ValueError("Could not decode frame bytes as an image")

    # Swap OpenCV's BGR order to RGB. This matches the image branch of
    # predict_answer, which converts PIL images to "RGB".
    return cv2.cvtColor(image_bgr, cv2.COLOR_BGR2RGB)

def predict_answer(image, video, question, max_tokens=100):

    text = f"A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions. USER: <image>\n{question}? ASSISTANT:"
    input_ids = tokenizer(text, return_tensors='pt').input_ids.to(device)

    
    if image is not None:
        # Process as an image
        image = image.convert("RGB")
        image_tensor = model.image_preprocess(image)
        
        #Generate the answer
        output_ids = model.generate(
            input_ids,
            max_new_tokens=max_tokens,
            images=image_tensor,
            use_cache=True)[0]
        
        return tokenizer.decode(output_ids[input_ids.shape[1]:], skip_special_tokens=True).strip()
        
    elif video is not None:
        # Process as a video
        frames = video_to_frames(video)
        answers = []
        for frame in frames:
            image = extract_frames(frame)
            image_tensor = model.image_preprocess(image)
            
            # Generate the answer
            output_ids = model.generate(
                input_ids,
                max_new_tokens=max_tokens,
                images=image_tensor,
                use_cache=True)[0]
            
            answer = tokenizer.decode(output_ids[input_ids.shape[1]:], skip_special_tokens=True).strip()
            answers.append(answer)
        return "\n".join(answers)
        
    else:
        return "Unsupported file type. Please upload an image or video."
        



def gradio_predict(image, video, question, max_tokens):
    """Gradio callback that forwards the UI inputs, unchanged, to ``predict_answer``.

    Returns whatever ``predict_answer`` returns.
    """
    return predict_answer(image, video, question, max_tokens)

# Gradio calls fn with one positional argument per input component, in order,
# so the inputs below must match gradio_predict(image, video, question, max_tokens).
iface = gr.Interface(
    fn=gradio_predict,
    inputs=[
        gr.Image(type="pil", label="Upload or Drag an Image"),
        gr.Video(label="Upload your video here"),
        gr.Textbox(label="Question"),
        # precision=0 makes Gradio round the value and pass it as an int.
        gr.Number(value=100, precision=0, label="Max new tokens"),
    ],
    outputs=gr.TextArea(label="Answer"),
    title="Video/Image Question Answering",
    description=(
        "Upload an image or a video and ask a question about it. "
        "For a video, the question is answered for about one frame per second, "
        "one answer per line."
    ),
)

iface.launch(debug=True)