capradeepgujaran committed
Commit
8ad7e0c
1 Parent(s): b04a1d0

Update app.py

Files changed (1):
  1. app.py  +102 -44
app.py CHANGED
@@ -11,86 +11,144 @@ import gradio as gr
 import tempfile
 import os
 import shutil
+from tqdm import tqdm
+import torch.nn as nn
+import math
+from concurrent.futures import ThreadPoolExecutor
+import numpy as np
 
 class VideoRAGTool:
     def __init__(self, clip_model_name: str = "openai/clip-vit-base-patch32",
                  blip_model_name: str = "Salesforce/blip-image-captioning-base"):
-        """
-        Initialize the Video RAG Tool with CLIP and BLIP models for frame analysis and captioning.
-        """
+        """Initialize with performance optimizations."""
         self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
-        # Initialize CLIP for frame retrieval
+        # Initialize models with optimization flags
         self.clip_model = CLIPModel.from_pretrained(clip_model_name).to(self.device)
         self.clip_processor = CLIPProcessor.from_pretrained(clip_model_name)
-
-        # Initialize BLIP for image captioning
         self.blip_processor = BlipProcessor.from_pretrained(blip_model_name)
         self.blip_model = BlipForConditionalGeneration.from_pretrained(blip_model_name).to(self.device)
 
+        # Enable eval mode for inference
+        self.clip_model.eval()
+        self.blip_model.eval()
+
+        # Batch processing settings
+        self.batch_size = 8  # Adjust based on your GPU memory
+
         self.frame_index = None
         self.frame_data = []
         self.logger = self._setup_logger()
 
-    def _setup_logger(self) -> logging.Logger:
-        logger = logging.getLogger('VideoRAGTool')
-        logger.setLevel(logging.INFO)
-        handler = logging.StreamHandler()
-        formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
-        handler.setFormatter(formatter)
-        logger.addHandler(handler)
-        return logger
-
+    @torch.no_grad()  # Disable gradient computation for inference
     def generate_caption(self, image: Image.Image) -> str:
-        """Generate a description for the given image using BLIP."""
+        """Optimized caption generation."""
         inputs = self.blip_processor(image, return_tensors="pt").to(self.device)
-        out = self.blip_model.generate(**inputs)
-        caption = self.blip_processor.decode(out[0], skip_special_tokens=True)
-        return caption
+        out = self.blip_model.generate(**inputs, max_length=30, num_beams=2)
+        return self.blip_processor.decode(out[0], skip_special_tokens=True)
+
+    def get_video_info(self, video_path: str) -> Tuple[int, float]:
+        """Get video frame count and FPS."""
+        cap = cv2.VideoCapture(video_path)
+        total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
+        fps = cap.get(cv2.CAP_PROP_FPS)
+        cap.release()
+        return total_frames, fps
+
+    def preprocess_frame(self, frame: np.ndarray, target_size: Tuple[int, int] = (224, 224)) -> Image.Image:
+        """Preprocess frame with resizing for efficiency."""
+        frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
+        image = Image.fromarray(frame_rgb)
+        return image.resize(target_size, Image.LANCZOS)
+
+    @torch.no_grad()
+    def process_batch(self, frames: List[Image.Image]) -> Tuple[np.ndarray, List[str]]:
+        """Process a batch of frames efficiently."""
+        # CLIP processing
+        clip_inputs = self.clip_processor(images=frames, return_tensors="pt", padding=True).to(self.device)
+        image_features = self.clip_model.get_image_features(**clip_inputs)
+
+        # BLIP processing
+        captions = []
+        blip_inputs = self.blip_processor(images=frames, return_tensors="pt", padding=True).to(self.device)
+        out = self.blip_model.generate(**blip_inputs, max_length=30, num_beams=2)
+
+        for o in out:
+            caption = self.blip_processor.decode(o, skip_special_tokens=True)
+            captions.append(caption)
+
+        return image_features.cpu().numpy(), captions
 
     def process_video(self, video_path: str, frame_interval: int = 30) -> None:
-        """Process video file and extract features from frames."""
+        """Optimized video processing with batching and progress tracking."""
         self.logger.info(f"Processing video: {video_path}")
+
+        total_frames, fps = self.get_video_info(video_path)
         cap = cv2.VideoCapture(video_path)
-        frame_count = 0
+
+        # Calculate total batches for progress bar
+        frames_to_process = total_frames // frame_interval
+        total_batches = math.ceil(frames_to_process / self.batch_size)
+
+        current_batch = []
         features_list = []
+        frame_count = 0
 
-        while cap.isOpened():
-            ret, frame = cap.read()
-            if not ret:
-                break
-
-            if frame_count % frame_interval == 0:
-                frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
-                image = Image.fromarray(frame_rgb)
-
-                # Generate caption for the frame
-                caption = self.generate_caption(image)
-
-                # Process frame with CLIP
-                inputs = self.clip_processor(images=image, return_tensors="pt").to(self.device)
-                image_features = self.clip_model.get_image_features(**inputs)
-
-                self.frame_data.append({
-                    'frame_number': frame_count,
-                    'timestamp': frame_count / cap.get(cv2.CAP_PROP_FPS),
-                    'caption': caption
-                })
-                features_list.append(image_features.cpu().detach().numpy())
-
-            frame_count += 1
+        with tqdm(total=frames_to_process, desc="Processing frames") as pbar:
+            while cap.isOpened():
+                ret, frame = cap.read()
+                if not ret:
+                    break
+
+                if frame_count % frame_interval == 0:
+                    # Preprocess frame
+                    processed_frame = self.preprocess_frame(frame)
+                    current_batch.append(processed_frame)
+
+                    # Process batch when it reaches batch_size
+                    if len(current_batch) == self.batch_size:
+                        batch_features, batch_captions = self.process_batch(current_batch)
+
+                        # Store results
+                        for i, (features, caption) in enumerate(zip(batch_features, batch_captions)):
+                            batch_frame_number = frame_count - (self.batch_size - i - 1) * frame_interval
+                            self.frame_data.append({
+                                'frame_number': batch_frame_number,
+                                'timestamp': batch_frame_number / fps,
+                                'caption': caption
+                            })
+                            features_list.append(features)
+
+                        current_batch = []
+                        pbar.update(self.batch_size)
+
+                frame_count += 1
+
+        # Process remaining frames
+        if current_batch:
+            batch_features, batch_captions = self.process_batch(current_batch)
+            for i, (features, caption) in enumerate(zip(batch_features, batch_captions)):
+                batch_frame_number = frame_count - (len(current_batch) - i - 1) * frame_interval
+                self.frame_data.append({
+                    'frame_number': batch_frame_number,
+                    'timestamp': batch_frame_number / fps,
+                    'caption': caption
+                })
+                features_list.append(features)
 
         cap.release()
 
         if not features_list:
             raise ValueError("No frames were processed from the video")
 
+        # Create FAISS index
         features_array = np.vstack(features_list)
         self.frame_index = faiss.IndexFlatL2(features_array.shape[1])
         self.frame_index.add(features_array)
 
         self.logger.info(f"Processed {len(self.frame_data)} frames from video")
 
     def query_video(self, query_text: str, k: int = 5) -> List[Dict]:
         """Query the video using natural language and return relevant frames."""
         self.logger.info(f"Processing query: {query_text}")