vigneshwar472 committed on
Commit 7892ef3 • 1 Parent(s): 707e96f

Create app.py

Files changed (1)
  1. app.py +131 -0
app.py ADDED
@@ -0,0 +1,131 @@
+ import spaces
+ import gradio as gr
+
+ import subprocess  # 🥲
+ subprocess.run(
+     "pip install flash-attn --no-build-isolation",
+     env={"FLASH_ATTENTION_SKIP_CUDA_BUILD": "TRUE"},
+     shell=True,
+ )
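+ # FLASH_ATTENTION_SKIP_CUDA_BUILD=TRUE makes the flash-attn install skip compiling its
+ # CUDA extension, presumably because no GPU is attached while the Space starts up
+ # (the @spaces.GPU decorator below requests a GPU only per inference call).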
+ # subprocess.run(
+ #     "pip install git+https://github.com/LLaVA-VL/LLaVA-NeXT.git",
+ #     shell=True,
+ # )
+
+ import torch
+ from llava.model.builder import load_pretrained_model
+ from llava.mm_utils import get_model_name_from_path, process_images, tokenizer_image_token
+ from llava.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN, IGNORE_INDEX
+ from llava.conversation import conv_templates, SeparatorStyle
+ import copy
+ import warnings
+ from decord import VideoReader, cpu
+ import numpy as np
+ import tempfile
+ import os
+ import shutil
+ # warnings.filterwarnings("ignore")
+ title = "# Demo of a VLM on Crime Scenes"
+ description1 = """The **🌋📹LLaVA-Video-7B-Qwen2** is a 7B-parameter model trained on the 🌋📹LLaVA-Video-178K dataset and the LLaVA-OneVision dataset. It is [based on the **Qwen2 language model**](https://huggingface.co/collections/Qwen/qwen2-6659360b33528ced941e557f), supporting a context window of up to 32K tokens. The model can process and interact with images, multi-image inputs, and videos, with specific optimizations for video analysis.
+ This model leverages the **SO400M vision backbone** for visual input and Qwen2 for language processing, making it highly efficient in multi-modal reasoning, including visual and video-based tasks.
+ 🌋📹LLaVA-Video also comes in larger [32B](https://huggingface.co/lmms-lab/LLaVA-NeXT-Video-32B-Qwen) and [72B](https://huggingface.co/lmms-lab/LLaVA-Video-72B-Qwen2) variants, as well as a [variant](https://huggingface.co/lmms-lab/LLaVA-Video-7B-Qwen2-Video-Only) trained only on the new synthetic data.
+ For further details, please visit the [Project Page](https://github.com/LLaVA-VL/LLaVA-NeXT) or check out the corresponding [research paper](https://arxiv.org/abs/2410.02713).
+ - **Architecture**: `LlavaQwenForCausalLM`
+ - **Attention Heads**: 28
+ - **Hidden Layers**: 28
+ - **Hidden Size**: 3584
+ """
+ description2 = """
+ We have leveraged this VLM for crime-scene video description. It achieves the expected performance, and we thank everyone who made this possible.
+ """
+
+
+ def load_video(video_path, max_frames_num, fps=1, force_sample=False):
+     if max_frames_num == 0:
+         # Return a dummy tuple so callers can always unpack three values.
+         return np.zeros((1, 336, 336, 3)), "", 0
+     vr = VideoReader(video_path, ctx=cpu(0), num_threads=1)
+     total_frame_num = len(vr)
+     video_time = total_frame_num / vr.get_avg_fps()
+     fps = round(vr.get_avg_fps() / fps)
+     frame_idx = [i for i in range(0, len(vr), fps)]
+     frame_time = [i / fps for i in frame_idx]
+     if len(frame_idx) > max_frames_num or force_sample:
+         sample_fps = max_frames_num
+         uniform_sampled_frames = np.linspace(0, total_frame_num - 1, sample_fps, dtype=int)
+         frame_idx = uniform_sampled_frames.tolist()
+         frame_time = [i / vr.get_avg_fps() for i in frame_idx]
+     frame_time = ",".join([f"{i:.2f}s" for i in frame_time])
+     spare_frames = vr.get_batch(frame_idx).asnumpy()
+     return spare_frames, frame_time, video_time
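+ # load_video returns a (num_frames, H, W, 3) uint8 array of sampled frames, a
+ # comma-separated string of their timestamps (e.g. "0.00s,1.00s,..."), and the
+ # total clip duration in seconds.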
+
+ # Load the model
+ pretrained = "lmms-lab/LLaVA-Video-7B-Qwen2"
+ model_name = "llava_qwen"
+ device = "cuda" if torch.cuda.is_available() else "cpu"
+ device_map = "auto"
+
+ print("Loading model...")
+ tokenizer, model, image_processor, max_length = load_pretrained_model(pretrained, None, model_name, torch_dtype="bfloat16", device_map=device_map)
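+ # load_pretrained_model returns the tokenizer, the LlavaQwen model, the image processor
+ # for the vision tower, and the model's maximum context length.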
+ model.eval()
+ print("Model loaded successfully!")
+
+ @spaces.GPU
+ def process_video(video_path, question):
+     max_frames_num = 64
+     video, frame_time, video_time = load_video(video_path, max_frames_num, 1, force_sample=True)
+     video = image_processor.preprocess(video, return_tensors="pt")["pixel_values"].to(device).bfloat16()
+     video = [video]
+
+     conv_template = "qwen_1_5"
+     time_instruction = f"The video lasts for {video_time:.2f} seconds, and {len(video[0])} frames are uniformly sampled from it. These frames are located at {frame_time}. Please answer the following questions related to this video."
+
+     full_question = DEFAULT_IMAGE_TOKEN + f"{time_instruction}\n{question}"
+
+     conv = copy.deepcopy(conv_templates[conv_template])
+     conv.append_message(conv.roles[0], full_question)
+     conv.append_message(conv.roles[1], None)
+     prompt_question = conv.get_prompt()
+
+     input_ids = tokenizer_image_token(prompt_question, tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt").unsqueeze(0).to(device)
+
+     with torch.no_grad():
+         output = model.generate(
+             input_ids,
+             images=video,
+             modalities=["video"],
+             do_sample=False,
+             temperature=0,
+             max_new_tokens=4096,
+         )
+
+     response = tokenizer.batch_decode(output, skip_special_tokens=True)[0].strip()
+     return response
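+ # Usage sketch (hypothetical file name):
+ #   process_video("incident_clip.mp4", "Describe what happens in this video.")
+ # returns the model's free-form answer about the 64 uniformly sampled frames.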
+
+ def gradio_interface(video_file, question):
+     if video_file is None:
+         return "Please upload a video file."
+     response = process_video(video_file, question)
+     return response
+
+ with gr.Blocks() as demo:
+     gr.Markdown(title)
+     with gr.Row():
+         with gr.Group():
+             gr.Markdown(description1)
+         with gr.Group():
+             gr.Markdown(description2)
+     with gr.Row():
+         with gr.Column():
+             video_input = gr.Video()
+             question_input = gr.Textbox(label="🙋🏻‍♂️ User Question", placeholder="Ask a question about the video, or ask it to describe the video")
+             submit_button = gr.Button("Ask")
+             output = gr.Textbox(label="VLM Bot")
+
+     submit_button.click(
+         fn=gradio_interface,
+         inputs=[video_input, question_input],
+         outputs=output
+     )
+
+ if __name__ == "__main__":
+     demo.launch(show_error=True, ssr_mode=False)
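+ # Assuming the LLaVA-NeXT package and its dependencies are installed, `python app.py`
+ # starts this Gradio demo locally; on Hugging Face Spaces, app.py is launched automatically.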