# vjepa2/handler.py
"""
V-JEPA 2 Custom Inference Handler for Hugging Face Inference Endpoints
Model: facebook/vjepa2-vitl-fpc64-256 (Large variant - good balance of performance/resources)
For ProofPath video assessment - extracts motion features from skill demonstration videos.
"""
from typing import Dict, List, Any, Optional
import torch
import numpy as np
import base64
import io
import tempfile
import os
class EndpointHandler:
def __init__(self, path: str = ""):
"""
Initialize V-JEPA 2 model for video feature extraction.
Args:
path: Path to the model directory (provided by HF Inference Endpoints)
"""
from transformers import AutoVideoProcessor, AutoModel
        # Always load from the official Facebook model on the Hugging Face Hub
        # (path points to /repository, which contains this custom handler, not the model weights)
        model_id = "facebook/vjepa2-vitl-fpc64-256"
# Determine device
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Load processor and model
self.processor = AutoVideoProcessor.from_pretrained(model_id)
self.model = AutoModel.from_pretrained(
model_id,
torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
device_map="auto" if torch.cuda.is_available() else None,
attn_implementation="sdpa" # Use scaled dot product attention for efficiency
)
if not torch.cuda.is_available():
self.model = self.model.to(self.device)
self.model.eval()
# Default config
self.default_num_frames = 64 # V-JEPA 2 is trained with 64 frames
    def _decode_video(self, video_data: Any):
        """
        Decode video from various input formats into a torchcodec VideoDecoder.
        Supports:
        - Base64-encoded video (raw string or data: URL)
        - URL to a video file
        - Raw bytes
        """
from torchcodec.decoders import VideoDecoder
# Handle base64 encoded video
if isinstance(video_data, str):
if video_data.startswith(('http://', 'https://')):
# URL - torchcodec can handle URLs directly
vr = VideoDecoder(video_data)
elif video_data.startswith('data:'):
# Data URL format
header, encoded = video_data.split(',', 1)
video_bytes = base64.b64decode(encoded)
# Write to temp file for torchcodec
with tempfile.NamedTemporaryFile(suffix='.mp4', delete=False) as f:
f.write(video_bytes)
temp_path = f.name
vr = VideoDecoder(temp_path)
os.unlink(temp_path)
else:
# Assume base64 encoded
video_bytes = base64.b64decode(video_data)
with tempfile.NamedTemporaryFile(suffix='.mp4', delete=False) as f:
f.write(video_bytes)
temp_path = f.name
vr = VideoDecoder(temp_path)
os.unlink(temp_path)
elif isinstance(video_data, bytes):
with tempfile.NamedTemporaryFile(suffix='.mp4', delete=False) as f:
f.write(video_data)
temp_path = f.name
vr = VideoDecoder(temp_path)
os.unlink(temp_path)
else:
raise ValueError(f"Unsupported video input type: {type(video_data)}")
return vr
def _sample_frames(
self,
video_decoder,
num_frames: int = 64,
sampling_strategy: str = "uniform"
) -> torch.Tensor:
"""
Sample frames from video decoder.
Args:
video_decoder: torchcodec VideoDecoder instance
num_frames: Number of frames to sample
sampling_strategy: "uniform" or "random"
"""
# Get video metadata
metadata = video_decoder.metadata
        # num_frames can be missing for some containers; fall back to a rough guess
        total_frames = getattr(metadata, "num_frames", None) or 1000
if sampling_strategy == "uniform":
# Uniformly sample frames across the video
if total_frames <= num_frames:
frame_idx = np.arange(total_frames)
else:
frame_idx = np.linspace(0, total_frames - 1, num_frames, dtype=int)
elif sampling_strategy == "random":
frame_idx = np.sort(np.random.choice(total_frames, min(num_frames, total_frames), replace=False))
else:
# Default to sequential from start
frame_idx = np.arange(min(num_frames, total_frames))
# Get frames: returns T x C x H x W
frames = video_decoder.get_frames_at(indices=frame_idx.tolist()).data
return frames
def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
"""
Process video and extract V-JEPA 2 features.
Expected input format:
{
"inputs": <base64_video_string or video_url>,
"parameters": {
"num_frames": 64, # Optional: number of frames to sample
"sampling_strategy": "uniform", # Optional: "uniform" or "random"
"return_predictor": true, # Optional: also return predictor features
"pooling": "mean" # Optional: "mean", "cls", or "none"
}
}
Returns:
{
"encoder_features": [...], # Encoder output features
"predictor_features": [...], # Optional predictor features
"feature_shape": [T, D], # Shape of features
}
"""
# Extract inputs
inputs = data.get("inputs")
if inputs is None:
inputs = data.get("video")
if inputs is None:
raise ValueError("No video input provided. Use 'inputs' or 'video' key.")
# Extract parameters
params = data.get("parameters", {})
num_frames = params.get("num_frames", self.default_num_frames)
sampling_strategy = params.get("sampling_strategy", "uniform")
return_predictor = params.get("return_predictor", False)
pooling = params.get("pooling", "mean")
try:
# Decode and sample video
video_decoder = self._decode_video(inputs)
frames = self._sample_frames(video_decoder, num_frames, sampling_strategy)
            # Process through the V-JEPA 2 processor; move inputs to the model's device
            # and match its dtype (fp16 on GPU) to avoid input/weight dtype mismatches
            processed = self.processor(frames, return_tensors="pt")
            processed = {
                k: v.to(self.model.device, dtype=self.model.dtype) if v.is_floating_point() else v.to(self.model.device)
                for k, v in processed.items()
            }
# Run inference
with torch.no_grad():
outputs = self.model(**processed)
# Extract encoder features
encoder_features = outputs.last_hidden_state # [batch, seq, hidden]
# Apply pooling
if pooling == "mean":
encoder_pooled = encoder_features.mean(dim=1) # [batch, hidden]
            elif pooling == "cls":
                # V-JEPA 2 has no dedicated CLS token; use the first token as a proxy
                encoder_pooled = encoder_features[:, 0, :]  # [batch, hidden]
else:
encoder_pooled = encoder_features # [batch, seq, hidden]
result = {
"encoder_features": encoder_pooled.cpu().numpy().tolist(),
"feature_shape": list(encoder_pooled.shape),
}
# Optionally include predictor features
if return_predictor and hasattr(outputs, 'predictor_output'):
predictor_features = outputs.predictor_output.last_hidden_state
if pooling == "mean":
predictor_pooled = predictor_features.mean(dim=1)
elif pooling == "cls":
predictor_pooled = predictor_features[:, 0, :]
else:
predictor_pooled = predictor_features
result["predictor_features"] = predictor_pooled.cpu().numpy().tolist()
result["predictor_shape"] = list(predictor_pooled.shape)
return result
except Exception as e:
return {"error": str(e), "error_type": type(e).__name__}