# vjepa2-384/handler.py
from typing import Dict, List, Any

import torch
import numpy as np
import base64
import tempfile
import os
import transformers
import logging

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
logger.info(f"transformers version: {transformers.__version__}")
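
# Hugging Face Inference Endpoints loads a custom handler from handler.py and
# expects an EndpointHandler class exposing __init__(path) and __call__(data);
# the class below implements that contract.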
class EndpointHandler:
    """
    Custom Hugging Face Inference Endpoints handler for V-JEPA2 video embeddings.

    This handler processes videos and returns pooled embeddings suitable for
    similarity search and vector databases such as LanceDB.

    Features:
    - Batch processing support for efficient inference
    - Handles variable-length videos via uniform frame sampling
    - Accepts video URLs and base64-encoded videos
    - Returns 1408-dimensional pooled embeddings
    """
def __init__(self, path: str = ""):
"""
Initialize the V-JEPA2 model and processor.
Args:
path: Path to the model weights (provided by HF Inference Endpoints)
"""
try:
from transformers import AutoVideoProcessor, AutoModel
            # Imported here only so a missing torchcodec install fails fast at
            # startup rather than on the first request
            from torchcodec.decoders import VideoDecoder  # noqa: F401
logger.info(f"Loading V-JEPA2 model from {path}")
# Determine device
self.device = "cuda" if torch.cuda.is_available() else "cpu"
logger.info(f"Using device: {self.device}")
# Load model without the classification head to get embeddings
# We use AutoModel instead of AutoModelForVideoClassification
self.model = AutoModel.from_pretrained(path).to(self.device)
self.processor = AutoVideoProcessor.from_pretrained(path)
# Set model to evaluation mode
self.model.eval()
# Store model config
self.frames_per_clip = getattr(self.model.config, 'frames_per_clip', 64)
self.hidden_size = getattr(self.model.config, 'hidden_size', 1408)
logger.info(f"Model loaded successfully. Frames per clip: {self.frames_per_clip}, Hidden size: {self.hidden_size}")
except Exception as e:
logger.error(f"Error initializing model: {str(e)}")
raise
    def _load_video_from_url(self, video_url: str) -> torch.Tensor:
        """
        Load a video from a URL and sample frames.

        Args:
            video_url: URL of the video file

        Returns:
            Video tensor with shape (frames, channels, height, width)
        """
        from torchcodec.decoders import VideoDecoder
try:
vr = VideoDecoder(video_url)
total_frames = len(vr)
# Uniform sampling to get exactly frames_per_clip frames
if total_frames < self.frames_per_clip:
logger.warning(f"Video has only {total_frames} frames, less than required {self.frames_per_clip}. Repeating frames.")
# Repeat frames to reach required count
frame_indices = np.tile(np.arange(total_frames),
(self.frames_per_clip // total_frames) + 1)[:self.frames_per_clip]
else:
# Uniform sampling across the video
frame_indices = np.linspace(0, total_frames - 1, self.frames_per_clip, dtype=int)
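            # Worked example of the sampling above (illustrative values):
            #   total_frames=10, frames_per_clip=4 -> linspace gives [0, 3, 6, 9]
            #   total_frames=3,  frames_per_clip=8 -> tile gives [0, 1, 2, 0, 1, 2, 0, 1]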
            # torchcodec expects a plain list of ints, not a numpy array
            video = vr.get_frames_at(indices=frame_indices.tolist()).data
return video
except Exception as e:
logger.error(f"Error loading video from URL {video_url}: {str(e)}")
raise
    def _load_video_from_base64(self, video_b64: str) -> torch.Tensor:
        """
        Load a video from base64-encoded data.

        Args:
            video_b64: Base64-encoded video data

        Returns:
            Video tensor with shape (frames, channels, height, width)
        """
        from torchcodec.decoders import VideoDecoder
try:
# Decode base64
video_bytes = base64.b64decode(video_b64)
            # Write to a temporary file and decode from its path
            with tempfile.NamedTemporaryFile(delete=False, suffix='.mp4') as tmp_file:
                tmp_file.write(video_bytes)
                tmp_path = tmp_file.name
try:
vr = VideoDecoder(tmp_path)
total_frames = len(vr)
# Uniform sampling
if total_frames < self.frames_per_clip:
frame_indices = np.tile(np.arange(total_frames),
(self.frames_per_clip // total_frames) + 1)[:self.frames_per_clip]
else:
frame_indices = np.linspace(0, total_frames - 1, self.frames_per_clip, dtype=int)
                video = vr.get_frames_at(indices=frame_indices.tolist()).data
return video
finally:
# Clean up temporary file
os.unlink(tmp_path)
except Exception as e:
logger.error(f"Error loading video from base64: {str(e)}")
raise
    def _extract_embeddings(self, videos: List[torch.Tensor]) -> np.ndarray:
        """
        Extract pooled embeddings from a batch of videos.

        Args:
            videos: List of video tensors, each (frames, channels, height, width)

        Returns:
            Numpy array of shape (batch_size, hidden_size) containing pooled embeddings
        """
try:
# Process videos through the processor
inputs = self.processor(videos, return_tensors="pt").to(self.device)
# Run inference
with torch.no_grad():
outputs = self.model(**inputs, output_hidden_states=True)
# Extract last hidden state and pool
# Shape: (batch_size, sequence_length, hidden_size)
last_hidden_state = outputs.last_hidden_state
# Mean pooling across sequence dimension
# Shape: (batch_size, hidden_size)
pooled_embeddings = last_hidden_state.mean(dim=1)
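            # Note: embeddings are returned unnormalized; for cosine-similarity
            # search (e.g. in LanceDB) callers may want to L2-normalize them first.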
# Convert to numpy
embeddings = pooled_embeddings.cpu().numpy()
return embeddings
except Exception as e:
logger.error(f"Error extracting embeddings: {str(e)}")
raise
    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """
        Process an inference request.

        Expected input formats:
        1. Single video URL:
           {"inputs": "https://example.com/video.mp4"}
        2. Batch of video URLs:
           {"inputs": ["url1", "url2", "url3"]}
        3. Base64-encoded video(s):
           {"inputs": "base64_encoded_string", "encoding": "base64"}

        The "encoding" field applies to every input in the request, so URL and
        base64 inputs cannot be mixed within a single batch.

        Returns:
            List of dictionaries, one per input, e.g.:
            [{"embedding": [1408-dim vector], "shape": [1408], "status": "success"}]
            Inputs that fail to load get a {"status": "error", ...} entry instead.
        """
try:
# Extract inputs
inputs = data.get("inputs")
encoding = data.get("encoding", "url")
if inputs is None:
raise ValueError("No 'inputs' provided in request data")
# Handle single input vs batch
if isinstance(inputs, str):
inputs = [inputs]
elif not isinstance(inputs, list):
raise ValueError(f"'inputs' must be a string or list, got {type(inputs)}")
logger.info(f"Processing {len(inputs)} video(s)")
# Load videos
videos = []
for idx, inp in enumerate(inputs):
try:
if encoding == "base64":
video = self._load_video_from_base64(inp)
else: # Default to URL
video = self._load_video_from_url(inp)
videos.append(video)
except Exception as e:
logger.error(f"Error loading video {idx}: {str(e)}")
# Return error for this specific video
videos.append(None)
# Filter out failed videos and track their indices
valid_videos = []
valid_indices = []
for idx, video in enumerate(videos):
if video is not None:
valid_videos.append(video)
valid_indices.append(idx)
if not valid_videos:
raise ValueError("No valid videos could be loaded")
# Extract embeddings for valid videos
embeddings = self._extract_embeddings(valid_videos)
# Prepare results
results = [None] * len(inputs)
for valid_idx, embedding in zip(valid_indices, embeddings):
results[valid_idx] = {
"embedding": embedding.tolist(),
"shape": list(embedding.shape),
"status": "success"
}
# Fill in errors for failed videos
for idx in range(len(inputs)):
if results[idx] is None:
results[idx] = {
"embedding": None,
"shape": None,
"status": "error",
"error": "Failed to load video"
}
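            # Example mixed outcome for inputs=["ok.mp4", "bad.mp4"] where the
            # second fails to load (hypothetical filenames):
            # [{"embedding": [...], "shape": [1408], "status": "success"},
            #  {"embedding": None, "shape": None, "status": "error",
            #   "error": "Failed to load video"}]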
logger.info(f"Successfully processed {len(valid_videos)}/{len(inputs)} videos")
return results
except Exception as e:
logger.error(f"Error in __call__: {str(e)}")
return [{"error": str(e), "status": "error"}]
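

# Minimal local smoke test -- a sketch, not part of the Inference Endpoints
# contract. It assumes model weights at a local path and uses a placeholder
# video URL; both are hypothetical and should be replaced before running.
if __name__ == "__main__":
    handler = EndpointHandler(path=".")  # "." assumes weights in the working dir
    request = {"inputs": "https://example.com/video.mp4"}  # placeholder URL
    for result in handler(request):
        if result.get("status") == "success":
            print(f"embedding shape: {result['shape']}")
        else:
            print(f"error: {result.get('error')}")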