Spaces:
Running
Running
Update utils/engine.py
Browse files- utils/engine.py +8 -6
utils/engine.py
CHANGED
|
@@ -9,13 +9,13 @@ from typing import Tuple, List
|
|
| 9 |
|
| 10 |
from utils.config import config, get_logger
|
| 11 |
from utils.models import device, clip_processor, clip_model, collection, chroma_client, vlm_model, vlm_tokenizer
|
|
|
|
| 12 |
logger = get_logger("Engine")
|
| 13 |
|
| 14 |
def process_and_index_video(video_path: str) -> Tuple[str, List[Image.Image]]:
|
| 15 |
if not video_path:
|
| 16 |
return "Please upload a video.", []
|
| 17 |
|
| 18 |
-
# Strict Cache Cleanup
|
| 19 |
if os.path.exists(config.cache_dir):
|
| 20 |
logger.info(f"Clearing old cache at {config.cache_dir}...")
|
| 21 |
shutil.rmtree(config.cache_dir, ignore_errors=True)
|
|
@@ -32,7 +32,6 @@ def process_and_index_video(video_path: str) -> Tuple[str, List[Image.Image]]:
|
|
| 32 |
rgb_first = cv2.cvtColor(first_frame, cv2.COLOR_BGR2RGB)
|
| 33 |
h, w, c = rgb_first.shape
|
| 34 |
|
| 35 |
-
# 🚨 STRICT SSD ALLOCATION
|
| 36 |
logger.info(f"Allocating strict Zarr v3 SSD cache at {config.cache_dir}...")
|
| 37 |
frame_cache = zarr.create_array(
|
| 38 |
config.cache_dir, shape=(0, h, w, c), chunks=(10, h, w, c), dtype='uint8', zarr_format=3
|
|
@@ -41,7 +40,6 @@ def process_and_index_video(video_path: str) -> Tuple[str, List[Image.Image]]:
|
|
| 41 |
timestamps, count, frame_idx = [], 0, 0
|
| 42 |
|
| 43 |
while success:
|
| 44 |
-
# 🚀 SPEED OPTIMIZATION: Only process exact frames needed
|
| 45 |
if count % frame_interval == 0:
|
| 46 |
rgb_image = cv2.cvtColor(first_frame, cv2.COLOR_BGR2RGB)
|
| 47 |
frame_cache.append(np.expand_dims(rgb_image, axis=0), axis=0)
|
|
@@ -63,7 +61,9 @@ def process_and_index_video(video_path: str) -> Tuple[str, List[Image.Image]]:
|
|
| 63 |
inputs = clip_processor(images=batch_pil, return_tensors="pt").to(device)
|
| 64 |
|
| 65 |
with torch.no_grad():
|
| 66 |
-
|
|
|
|
|
|
|
| 67 |
|
| 68 |
normalized = (features / features.norm(p=2, dim=-1, keepdim=True)).cpu().tolist()
|
| 69 |
all_embeddings.extend(normalized)
|
|
@@ -90,12 +90,14 @@ def ask_video_question(query: str) -> Tuple[str, List[Image.Image]]:
|
|
| 90 |
|
| 91 |
inputs = clip_processor(text=[query], return_tensors="pt", padding=True).to(device)
|
| 92 |
with torch.no_grad():
|
| 93 |
-
|
|
|
|
|
|
|
|
|
|
| 94 |
text_embedding = (text_features / text_features.norm(p=2, dim=-1, keepdim=True)).cpu().tolist()
|
| 95 |
|
| 96 |
results = collection.query(query_embeddings=text_embedding, n_results=3)
|
| 97 |
|
| 98 |
-
# Read strictly from SSD
|
| 99 |
frame_cache = zarr.open_array(config.cache_dir, mode="r")
|
| 100 |
|
| 101 |
retrieved_images = []
|
|
|
|
| 9 |
|
| 10 |
from utils.config import config, get_logger
|
| 11 |
from utils.models import device, clip_processor, clip_model, collection, chroma_client, vlm_model, vlm_tokenizer
|
| 12 |
+
|
| 13 |
logger = get_logger("Engine")
|
| 14 |
|
| 15 |
def process_and_index_video(video_path: str) -> Tuple[str, List[Image.Image]]:
|
| 16 |
if not video_path:
|
| 17 |
return "Please upload a video.", []
|
| 18 |
|
|
|
|
| 19 |
if os.path.exists(config.cache_dir):
|
| 20 |
logger.info(f"Clearing old cache at {config.cache_dir}...")
|
| 21 |
shutil.rmtree(config.cache_dir, ignore_errors=True)
|
|
|
|
| 32 |
rgb_first = cv2.cvtColor(first_frame, cv2.COLOR_BGR2RGB)
|
| 33 |
h, w, c = rgb_first.shape
|
| 34 |
|
|
|
|
| 35 |
logger.info(f"Allocating strict Zarr v3 SSD cache at {config.cache_dir}...")
|
| 36 |
frame_cache = zarr.create_array(
|
| 37 |
config.cache_dir, shape=(0, h, w, c), chunks=(10, h, w, c), dtype='uint8', zarr_format=3
|
|
|
|
| 40 |
timestamps, count, frame_idx = [], 0, 0
|
| 41 |
|
| 42 |
while success:
|
|
|
|
| 43 |
if count % frame_interval == 0:
|
| 44 |
rgb_image = cv2.cvtColor(first_frame, cv2.COLOR_BGR2RGB)
|
| 45 |
frame_cache.append(np.expand_dims(rgb_image, axis=0), axis=0)
|
|
|
|
| 61 |
inputs = clip_processor(images=batch_pil, return_tensors="pt").to(device)
|
| 62 |
|
| 63 |
with torch.no_grad():
|
| 64 |
+
# 🚨 BUGFIX: Manually extract and project the vision features
|
| 65 |
+
vision_outputs = clip_model.vision_model(**inputs)
|
| 66 |
+
features = clip_model.visual_projection(vision_outputs.pooler_output)
|
| 67 |
|
| 68 |
normalized = (features / features.norm(p=2, dim=-1, keepdim=True)).cpu().tolist()
|
| 69 |
all_embeddings.extend(normalized)
|
|
|
|
| 90 |
|
| 91 |
inputs = clip_processor(text=[query], return_tensors="pt", padding=True).to(device)
|
| 92 |
with torch.no_grad():
|
| 93 |
+
# 🚨 BUGFIX: Manually extract and project the text features
|
| 94 |
+
text_outputs = clip_model.text_model(**inputs)
|
| 95 |
+
text_features = clip_model.text_projection(text_outputs.pooler_output)
|
| 96 |
+
|
| 97 |
text_embedding = (text_features / text_features.norm(p=2, dim=-1, keepdim=True)).cpu().tolist()
|
| 98 |
|
| 99 |
results = collection.query(query_embeddings=text_embedding, n_results=3)
|
| 100 |
|
|
|
|
| 101 |
frame_cache = zarr.open_array(config.cache_dir, mode="r")
|
| 102 |
|
| 103 |
retrieved_images = []
|