Hanzo03 commited on
Commit
b28f276
·
verified ·
1 Parent(s): 08d3193

Update utils/engine.py

Browse files
Files changed (1) hide show
  1. utils/engine.py +8 -6
utils/engine.py CHANGED
@@ -9,13 +9,13 @@ from typing import Tuple, List
9
 
10
  from utils.config import config, get_logger
11
  from utils.models import device, clip_processor, clip_model, collection, chroma_client, vlm_model, vlm_tokenizer
 
12
  logger = get_logger("Engine")
13
 
14
  def process_and_index_video(video_path: str) -> Tuple[str, List[Image.Image]]:
15
  if not video_path:
16
  return "Please upload a video.", []
17
 
18
- # Strict Cache Cleanup
19
  if os.path.exists(config.cache_dir):
20
  logger.info(f"Clearing old cache at {config.cache_dir}...")
21
  shutil.rmtree(config.cache_dir, ignore_errors=True)
@@ -32,7 +32,6 @@ def process_and_index_video(video_path: str) -> Tuple[str, List[Image.Image]]:
32
  rgb_first = cv2.cvtColor(first_frame, cv2.COLOR_BGR2RGB)
33
  h, w, c = rgb_first.shape
34
 
35
- # 🚨 STRICT SSD ALLOCATION
36
  logger.info(f"Allocating strict Zarr v3 SSD cache at {config.cache_dir}...")
37
  frame_cache = zarr.create_array(
38
  config.cache_dir, shape=(0, h, w, c), chunks=(10, h, w, c), dtype='uint8', zarr_format=3
@@ -41,7 +40,6 @@ def process_and_index_video(video_path: str) -> Tuple[str, List[Image.Image]]:
41
  timestamps, count, frame_idx = [], 0, 0
42
 
43
  while success:
44
- # 🚀 SPEED OPTIMIZATION: Only process exact frames needed
45
  if count % frame_interval == 0:
46
  rgb_image = cv2.cvtColor(first_frame, cv2.COLOR_BGR2RGB)
47
  frame_cache.append(np.expand_dims(rgb_image, axis=0), axis=0)
@@ -63,7 +61,9 @@ def process_and_index_video(video_path: str) -> Tuple[str, List[Image.Image]]:
63
  inputs = clip_processor(images=batch_pil, return_tensors="pt").to(device)
64
 
65
  with torch.no_grad():
66
- features = clip_model.get_image_features(**inputs)
 
 
67
 
68
  normalized = (features / features.norm(p=2, dim=-1, keepdim=True)).cpu().tolist()
69
  all_embeddings.extend(normalized)
@@ -90,12 +90,14 @@ def ask_video_question(query: str) -> Tuple[str, List[Image.Image]]:
90
 
91
  inputs = clip_processor(text=[query], return_tensors="pt", padding=True).to(device)
92
  with torch.no_grad():
93
- text_features = clip_model.get_text_features(**inputs)
 
 
 
94
  text_embedding = (text_features / text_features.norm(p=2, dim=-1, keepdim=True)).cpu().tolist()
95
 
96
  results = collection.query(query_embeddings=text_embedding, n_results=3)
97
 
98
- # Read strictly from SSD
99
  frame_cache = zarr.open_array(config.cache_dir, mode="r")
100
 
101
  retrieved_images = []
 
9
 
10
  from utils.config import config, get_logger
11
  from utils.models import device, clip_processor, clip_model, collection, chroma_client, vlm_model, vlm_tokenizer
12
+
13
  logger = get_logger("Engine")
14
 
15
  def process_and_index_video(video_path: str) -> Tuple[str, List[Image.Image]]:
16
  if not video_path:
17
  return "Please upload a video.", []
18
 
 
19
  if os.path.exists(config.cache_dir):
20
  logger.info(f"Clearing old cache at {config.cache_dir}...")
21
  shutil.rmtree(config.cache_dir, ignore_errors=True)
 
32
  rgb_first = cv2.cvtColor(first_frame, cv2.COLOR_BGR2RGB)
33
  h, w, c = rgb_first.shape
34
 
 
35
  logger.info(f"Allocating strict Zarr v3 SSD cache at {config.cache_dir}...")
36
  frame_cache = zarr.create_array(
37
  config.cache_dir, shape=(0, h, w, c), chunks=(10, h, w, c), dtype='uint8', zarr_format=3
 
40
  timestamps, count, frame_idx = [], 0, 0
41
 
42
  while success:
 
43
  if count % frame_interval == 0:
44
  rgb_image = cv2.cvtColor(first_frame, cv2.COLOR_BGR2RGB)
45
  frame_cache.append(np.expand_dims(rgb_image, axis=0), axis=0)
 
61
  inputs = clip_processor(images=batch_pil, return_tensors="pt").to(device)
62
 
63
  with torch.no_grad():
64
+ # 🚨 BUGFIX: Manually extract and project the vision features
65
+ vision_outputs = clip_model.vision_model(**inputs)
66
+ features = clip_model.visual_projection(vision_outputs.pooler_output)
67
 
68
  normalized = (features / features.norm(p=2, dim=-1, keepdim=True)).cpu().tolist()
69
  all_embeddings.extend(normalized)
 
90
 
91
  inputs = clip_processor(text=[query], return_tensors="pt", padding=True).to(device)
92
  with torch.no_grad():
93
+ # 🚨 BUGFIX: Manually extract and project the text features
94
+ text_outputs = clip_model.text_model(**inputs)
95
+ text_features = clip_model.text_projection(text_outputs.pooler_output)
96
+
97
  text_embedding = (text_features / text_features.norm(p=2, dim=-1, keepdim=True)).cpu().tolist()
98
 
99
  results = collection.query(query_embeddings=text_embedding, n_results=3)
100
 
 
101
  frame_cache = zarr.open_array(config.cache_dir, mode="r")
102
 
103
  retrieved_images = []