Hanzo03 committed on
Commit 04705fd · 1 Parent(s): 2297897
Files changed (13)
  1. .env +1 -0
  2. .gitattributes copy +45 -0
  3. .gitignore +2 -0
  4. .python-version +1 -0
  5. README copy.md +12 -0
  6. app.py +29 -0
  7. pyproject.toml +19 -0
  8. requirements.txt +11 -0
  9. utils/__init__.py +0 -0
  10. utils/config.py +34 -0
  11. utils/engine.py +110 -0
  12. utils/models.py +29 -0
  13. uv.lock +0 -0
.env ADDED
@@ -0,0 +1 @@
+ HUGGINGFACEHUB_API_TOKEN = "your_huggingface_api_key_here"
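Note that nothing in this commit actually reads .env at runtime (python-dotenv is not among the dependencies), so the token here is inert unless loaded explicitly. A minimal sketch, assuming python-dotenv were added to the requirements:

```python
# Sketch only: python-dotenv is NOT in requirements.txt; adding it is an assumption.
import os
from dotenv import load_dotenv

load_dotenv()  # reads key=value pairs from ./.env into os.environ
token = os.environ.get("HUGGINGFACEHUB_API_TOKEN")
```

On Spaces the usual route is a repository secret set in the Space settings rather than a committed .env, since this file is public once pushed.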
.gitattributes copy ADDED
@@ -0,0 +1,45 @@
+ *.7z filter=lfs diff=lfs merge=lfs -text
+ *.arrow filter=lfs diff=lfs merge=lfs -text
+ *.bin filter=lfs diff=lfs merge=lfs -text
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
+ *.ftz filter=lfs diff=lfs merge=lfs -text
+ *.gz filter=lfs diff=lfs merge=lfs -text
+ *.h5 filter=lfs diff=lfs merge=lfs -text
+ *.joblib filter=lfs diff=lfs merge=lfs -text
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
+ *.model filter=lfs diff=lfs merge=lfs -text
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
+ *.npy filter=lfs diff=lfs merge=lfs -text
+ *.npz filter=lfs diff=lfs merge=lfs -text
+ *.onnx filter=lfs diff=lfs merge=lfs -text
+ *.ot filter=lfs diff=lfs merge=lfs -text
+ *.parquet filter=lfs diff=lfs merge=lfs -text
+ *.pb filter=lfs diff=lfs merge=lfs -text
+ *.pickle filter=lfs diff=lfs merge=lfs -text
+ *.pkl filter=lfs diff=lfs merge=lfs -text
+ *.pt filter=lfs diff=lfs merge=lfs -text
+ *.pth filter=lfs diff=lfs merge=lfs -text
+ *.rar filter=lfs diff=lfs merge=lfs -text
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
+ *.tar filter=lfs diff=lfs merge=lfs -text
+ *.tflite filter=lfs diff=lfs merge=lfs -text
+ *.tgz filter=lfs diff=lfs merge=lfs -text
+ *.wasm filter=lfs diff=lfs merge=lfs -text
+ *.xz filter=lfs diff=lfs merge=lfs -text
+ *.zip filter=lfs diff=lfs merge=lfs -text
+ *.zst filter=lfs diff=lfs merge=lfs -text
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
+ # Python-generated files
+ __pycache__/
+ *.py[oc]
+ build/
+ dist/
+ wheels/
+ *.egg-info
+
+ # Virtual environments
+ .venv
.gitignore ADDED
@@ -0,0 +1,2 @@
+ __pycache__/
+ .venv/
.python-version ADDED
@@ -0,0 +1 @@
+ 3.12
README copy.md ADDED
@@ -0,0 +1,12 @@
+ ---
+ title: RAG
+ emoji: 👁
+ colorFrom: pink
+ colorTo: purple
+ sdk: gradio
+ sdk_version: 6.1.0
+ app_file: app.py
+ pinned: false
+ ---
+
+ Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py ADDED
@@ -0,0 +1,29 @@
+ import gradio as gr
+ from utils.config import get_logger
+ from utils.engine import process_and_index_video, ask_video_question
+
+ logger = get_logger("GradioUI")
+ logger.info("Constructing UI...")
+
+ with gr.Blocks(theme=gr.themes.Soft()) as demo:
+     gr.Markdown("# 🧠 Multimodal Video RAG (Vision Q/A)")
+
+     with gr.Row():
+         with gr.Column(scale=1):
+             video_input = gr.Video(label="Upload Video")
+             index_btn = gr.Button("1. Process & Index Video", variant="primary")
+             status_out = gr.Textbox(label="System Status", interactive=False)
+
+         with gr.Column(scale=1):
+             query_input = gr.Textbox(label="Ask a visual question:")
+             ask_btn = gr.Button("2. Ask Question")
+             answer_out = gr.Textbox(label="VLM Answer", lines=4)
+
+     gallery_out = gr.Gallery(label="Context Frames", show_label=True, columns=3)
+
+     index_btn.click(fn=process_and_index_video, inputs=[video_input], outputs=[status_out, gallery_out])
+     ask_btn.click(fn=ask_video_question, inputs=[query_input], outputs=[answer_out, gallery_out])
+
+ if __name__ == "__main__":
+     logger.info("Launching server...")
+     demo.launch()
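Since the two click handlers are plain functions, the pipeline can be smoke-tested without launching Gradio. A hypothetical check from the repo root ("sample.mp4" is a placeholder path, not part of the commit); note that importing utils.engine pulls in utils/models.py, which downloads CLIP and moondream2 on first run:

```python
# Hypothetical smoke test; "sample.mp4" is a placeholder path.
from utils.engine import process_and_index_video, ask_video_question

status, preview = process_and_index_video("sample.mp4")
print(status)  # e.g. "Processed N frames strictly on SSD cache."
answer, context = ask_video_question("What is happening in the video?")
print(answer)
```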
pyproject.toml ADDED
@@ -0,0 +1,19 @@
+ [project]
+ name = "rag"
+ version = "0.1.0"
+ description = "Add your description here"
+ readme = "README.md"
+ requires-python = ">=3.12"
+ dependencies = [
+     "opencv-python",
+     "gradio",
+     "pillow",
+     "torch",
+     "transformers",
+     "chromadb",
+     "zarr",
+     "einops>=0.8.2",
+     "torchvision>=0.25.0",
+     "pydantic>=2.12.5",
+     "hf-transfer>=0.1.9",
+ ]
requirements.txt ADDED
@@ -0,0 +1,11 @@
+ gradio
+ opencv-python-headless
+ transformers
+ chromadb
+ torch
+ torchvision
+ pillow
+ zarr
+ pydantic
+ einops
+ hf_transfer
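Several import names differ from the package names above (opencv-python-headless imports as cv2, pillow as PIL). A quick post-install sanity check, as a sketch:

```python
# Verifies the installed stack imports cleanly; asserts nothing about versions.
import importlib

modules = ["gradio", "cv2", "transformers", "chromadb", "torch",
           "torchvision", "PIL", "zarr", "pydantic", "einops", "hf_transfer"]
for mod in modules:
    importlib.import_module(mod)
print("All imports OK")
```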
utils/__init__.py ADDED
File without changes
utils/config.py ADDED
@@ -0,0 +1,34 @@
+ import logging
+ import os
+ from pydantic import BaseModel, Field
+
+ # 🚀 SPEED OPTIMIZATION: Force HF to use the high-speed Rust transfer protocol
+ os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"
+
+ # Set standard logging
+ logging.basicConfig(
+     level=logging.INFO,
+     format="%(asctime)s [%(levelname)s] %(name)s: %(message)s",
+     datefmt="%H:%M:%S"
+ )
+
+ # 🔇 SILENCE THE HTTP SPAM
+ logging.getLogger("httpx").setLevel(logging.WARNING)
+ logging.getLogger("httpcore").setLevel(logging.WARNING)
+ logging.getLogger("uvicorn.access").setLevel(logging.WARNING)
+ logging.getLogger("uvicorn.error").setLevel(logging.WARNING)
+
+ def get_logger(name: str):
+     return logging.getLogger(name)
+
+ class AppConfig(BaseModel):
+     # FORCE Zarr to use the guaranteed-writable /tmp directory on HF Spaces
+     cache_dir: str = Field(default="/tmp/video_cache.zarr", description="Strict Zarr v3 SSD cache")
+     clip_model_id: str = Field(default="openai/clip-vit-base-patch32")
+     vlm_model_id: str = Field(default="vikhyatk/moondream2")
+     vlm_revision: str = Field(default="2024-08-26")
+     collection_name: str = Field(default="multimodal_rag")
+     default_fps: int = Field(default=1)
+     batch_size: int = Field(default=64, description="Batch size for faster CLIP processing")
+
+ config = AppConfig()
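Because AppConfig is an ordinary pydantic BaseModel, any field can be overridden at construction time and pydantic validates the types. A short sketch:

```python
from utils.config import AppConfig

# Sample the video at 2 fps and use smaller CLIP batches; pydantic rejects bad types.
custom = AppConfig(default_fps=2, batch_size=32)
print(custom.cache_dir)  # "/tmp/video_cache.zarr" (unchanged default)
```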
utils/engine.py ADDED
@@ -0,0 +1,110 @@
+ import os
+ import shutil
+ import cv2
+ import torch
+ import numpy as np
+ import zarr
+ from PIL import Image
+ from typing import Tuple, List
+
+ from utils.config import config, get_logger
+ from utils.models import device, clip_processor, clip_model, collection, chroma_client, vlm_model, vlm_tokenizer
+ logger = get_logger("Engine")
+
+ def process_and_index_video(video_path: str) -> Tuple[str, List[Image.Image]]:
+     if not video_path:
+         return "Please upload a video.", []
+
+     # Strict Cache Cleanup
+     if os.path.exists(config.cache_dir):
+         logger.info(f"Clearing old cache at {config.cache_dir}...")
+         shutil.rmtree(config.cache_dir, ignore_errors=True)
+
+     logger.info("Starting fast extraction process...")
+     vidcap = cv2.VideoCapture(video_path)
+     video_fps = vidcap.get(cv2.CAP_PROP_FPS)
+     frame_interval = max(1, int(video_fps / config.default_fps))
+
+     success, frame = vidcap.read()
+     if not success:
+         return "Failed to read video.", []
+
+     rgb_first = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
+     h, w, c = rgb_first.shape
+
+     # 🚨 STRICT SSD ALLOCATION
+     logger.info(f"Allocating strict Zarr v3 SSD cache at {config.cache_dir}...")
+     frame_cache = zarr.create_array(
+         config.cache_dir, shape=(0, h, w, c), chunks=(10, h, w, c), dtype='uint8', zarr_format=3
+     )
+
+     timestamps, count, frame_idx = [], 0, 0
+
+     while success:
+         # 🚀 SPEED OPTIMIZATION: Only process exact frames needed
+         if count % frame_interval == 0:
+             rgb_image = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
+             frame_cache.append(np.expand_dims(rgb_image, axis=0), axis=0)
+             timestamps.append(count / video_fps)
+             frame_idx += 1
+
+         success, frame = vidcap.read()
+         count += 1
+
+     vidcap.release()
+
+     logger.info("Generating CLIP embeddings in batches...")
+     all_embeddings = []
+     total_frames = frame_cache.shape[0]
+
+     for i in range(0, total_frames, config.batch_size):
+         batch_arrays = frame_cache[i : i + config.batch_size]
+         batch_pil = [Image.fromarray(arr) for arr in batch_arrays]
+         inputs = clip_processor(images=batch_pil, return_tensors="pt").to(device)
+
+         with torch.no_grad():
+             features = clip_model.get_image_features(**inputs)
+
+         normalized = (features / features.norm(p=2, dim=-1, keepdim=True)).cpu().tolist()
+         all_embeddings.extend(normalized)
+
+     logger.info("Indexing into ChromaDB...")
+     ids = [f"frame_{i}" for i in range(total_frames)]
+     metadatas = [{"timestamp": ts, "frame_idx": i} for i, ts in enumerate(timestamps)]
+
+     # Rebuild the collection so stale frames from a previous video are dropped
+     global collection
+     chroma_client.delete_collection(config.collection_name)
+     collection = chroma_client.create_collection(config.collection_name)
+
+     collection.add(embeddings=all_embeddings, metadatas=metadatas, ids=ids)
+
+     sample_frames = [Image.fromarray(frame_cache[i]) for i in range(min(3, total_frames))]
+     return f"Processed {total_frames} frames strictly on SSD cache.", sample_frames
+
+
+ def ask_video_question(query: str) -> Tuple[str, List[Image.Image]]:
+     if collection.count() == 0:
+         return "Please process a video first.", []
+
+     logger.info(f"Processing query: '{query}'")
+
+     inputs = clip_processor(text=[query], return_tensors="pt", padding=True).to(device)
+     with torch.no_grad():
+         text_features = clip_model.get_text_features(**inputs)
+     text_embedding = (text_features / text_features.norm(p=2, dim=-1, keepdim=True)).cpu().tolist()
+
+     results = collection.query(query_embeddings=text_embedding, n_results=3)
+
+     # Read strictly from SSD
+     frame_cache = zarr.open_array(config.cache_dir, mode="r")
+
+     retrieved_images = []
+     for metadata in results['metadatas'][0]:
+         img_array = frame_cache[int(metadata['frame_idx'])]
+         retrieved_images.append(Image.fromarray(img_array))
+
+     logger.info("Generating VLM answer...")
+     encoded_image = vlm_model.encode_image(retrieved_images[0])
+     answer = vlm_model.answer_question(encoded_image, query, vlm_tokenizer)
+
+     return answer, retrieved_images
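Both the image and text embeddings are L2-normalized before they reach Chroma. With unit vectors, the squared L2 distance of Chroma's default `l2` space is monotone in cosine similarity (‖a − b‖² = 2 − 2·cos θ), so the retrieval order matches CLIP's intended image-text score. A small self-contained check of that identity:

```python
import torch

a = torch.randn(512)
b = torch.randn(512)
a = a / a.norm(p=2)  # unit vectors, as in engine.py
b = b / b.norm(p=2)

cos = torch.dot(a, b)
l2_sq = (a - b).pow(2).sum()
# For unit vectors: ||a - b||^2 = 2 - 2*cos(theta)
assert torch.allclose(l2_sq, 2 - 2 * cos, atol=1e-5)
```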
utils/models.py ADDED
@@ -0,0 +1,29 @@
+ import torch
+ import chromadb
+ from transformers import CLIPProcessor, CLIPModel, AutoModelForCausalLM, AutoTokenizer
+ from utils.config import config, get_logger
+
+ logger = get_logger("Models")
+ device = "cuda" if torch.cuda.is_available() else "cpu"
+ logger.info(f"Initializing models on: {device.upper()}")
+
+ # 1. Load CLIP
+ logger.info(f"Loading CLIP ({config.clip_model_id})...")
+ clip_processor = CLIPProcessor.from_pretrained(config.clip_model_id)
+ clip_model = CLIPModel.from_pretrained(config.clip_model_id).to(device)
+
+ # 2. Initialize ChromaDB
+ logger.info("Initializing ChromaDB...")
+ chroma_client = chromadb.Client()
+ try:
+     chroma_client.delete_collection(config.collection_name)
+ except Exception:
+     pass
+ collection = chroma_client.create_collection(name=config.collection_name)
+
+ # 3. Load VLM
+ logger.info(f"Loading VLM ({config.vlm_model_id})...")
+ vlm_model = AutoModelForCausalLM.from_pretrained(
+     config.vlm_model_id, trust_remote_code=True, revision=config.vlm_revision
+ ).to(device)
+ vlm_tokenizer = AutoTokenizer.from_pretrained(config.vlm_model_id, revision=config.vlm_revision)
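Pinning `revision` matters here: with `trust_remote_code=True`, the `encode_image` and `answer_question` methods come from the moondream2 repository code at that revision, not from transformers itself, so an unpinned load could change the API underneath the engine. A standalone check of the pinned pair, mirroring how engine.py calls it (the image path is a placeholder):

```python
from PIL import Image
from transformers import AutoModelForCausalLM, AutoTokenizer

model = AutoModelForCausalLM.from_pretrained(
    "vikhyatk/moondream2", trust_remote_code=True, revision="2024-08-26"
)
tokenizer = AutoTokenizer.from_pretrained("vikhyatk/moondream2", revision="2024-08-26")

img = Image.open("frame.jpg")  # placeholder path
enc = model.encode_image(img)
print(model.answer_question(enc, "Describe this image.", tokenizer))
```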
uv.lock ADDED
The diff for this file is too large to render. See raw diff