import os

# Configure Hugging Face caches to use the writable /cache volume in Spaces.
# Set before importing transformers so the library picks these paths up.
os.environ["HF_HOME"] = "/cache"
os.environ["TRANSFORMERS_CACHE"] = "/cache"
os.environ["HF_DATASETS_CACHE"] = "/cache"

import logging
import shutil
import tempfile
import threading

from flask import Flask, request, jsonify
import numpy as np
import torch
import av
import cv2
from transformers import VideoMAEForVideoClassification, VideoMAEImageProcessor
from PIL import Image
from torchvision.transforms import Compose, Resize, ToTensor

# Initialize Flask app
app = Flask(__name__)

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Number of frames sampled from each uploaded video and fed to the model.
CLIP_LEN = 16

# Globals for model, processor, and transforms (populated lazily by load_model).
device = "cuda" if torch.cuda.is_available() else "cpu"
model = None
processor = None
transform = None
# Serialises first-time loading: Flask's dev server handles requests on
# multiple threads, and two concurrent first requests must not both load the
# weights or observe a half-initialised set of globals.
_model_lock = threading.Lock()


def load_model():
    """Load the model, processor and frame transform into module globals once.

    Returns:
        tuple: ``(model, processor, transform)``.
    """
    global model, processor, transform
    with _model_lock:
        if model is None:
            model_name = "OPear/videomae-large-finetuned-UCF-Crime"
            logger.info(f"Loading model {model_name} on device {device}")
            # Downloads will go to /cache (see the environment set above).
            loaded_model = VideoMAEForVideoClassification.from_pretrained(model_name).to(device)
            loaded_processor = VideoMAEImageProcessor.from_pretrained(model_name)
            loaded_transform = Compose([
                Resize((224, 224)),
                ToTensor(),
            ])
            processor = loaded_processor
            transform = loaded_transform
            # Publish the model last: "model is not None" is the loaded flag.
            model = loaded_model
            logger.info("Model and processor loaded successfully")
    return model, processor, transform


def sample_frame_indices(clip_len=16, frame_sample_rate=1, seg_len=0):
    """Pick ``clip_len`` frame indices from a video with ``seg_len`` frames.

    Videos no longer than the sampling window
    (``clip_len * frame_sample_rate`` frames) are covered end to end with
    evenly spaced indices, repeating frames when the video is short.
    Longer videos use a window placed at a random position and sampled evenly
    inside it. Repeated calls on the same long video can therefore return
    different indices.

    Args:
        clip_len: Number of indices to return.
        frame_sample_rate: Spacing between sampled frames inside the window.
            The default of 1 samples consecutive frames.
        seg_len: Total number of frames in the video; must be >= 1.

    Returns:
        np.ndarray: ``clip_len`` integer indices, clipped to ``[0, seg_len - 1]``.

    Raises:
        ValueError: If ``seg_len`` is less than 1 (there is nothing to sample).
    """
    if seg_len < 1:
        raise ValueError(f"seg_len must be >= 1, got {seg_len}")
    window = max(1, int(clip_len * frame_sample_rate))
    if seg_len <= window:
        indices = np.linspace(0, seg_len - 1, num=clip_len, dtype=int)
    else:
        end_idx = np.random.randint(window, seg_len)
        start_idx = end_idx - window
        indices = np.linspace(start_idx, end_idx - 1, num=clip_len, dtype=int)
    return np.clip(indices, 0, seg_len - 1)


def _count_frames(video_path, video_stream):
    """Return the stream's frame count, falling back to OpenCV's estimate."""
    if video_stream.frames > 0:
        return video_stream.frames
    cap = cv2.VideoCapture(video_path)
    try:
        return int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
    finally:
        cap.release()


def _decode_with_pyav(container, indices):
    """Decode the frames at ``indices`` with PyAV.

    Returns one RGB frame per entry of ``indices``, including repeats, or
    None if some requested index was never reached.
    """
    wanted = set(indices.tolist())
    last = max(wanted)
    decoded = {}
    container.seek(0)
    for i, frame in enumerate(container.decode(video=0)):
        if i > last:
            break
        if i in wanted:
            decoded[i] = frame.to_ndarray(format="rgb24")
    if len(decoded) < len(wanted):
        return None
    # Expand via the index list so repeated indices yield repeated frames.
    return [decoded[i] for i in indices.tolist()]


def _decode_with_opencv(video_path, indices):
    """Read the frames at ``indices`` with OpenCV, skipping unreadable ones."""
    frames = []
    cap = cv2.VideoCapture(video_path)
    try:
        for i in indices.tolist():
            cap.set(cv2.CAP_PROP_POS_FRAMES, i)
            ret, frame = cap.read()
            if ret:
                frames.append(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
    finally:
        cap.release()
    return frames


def process_video(video_path):
    """Extract ``CLIP_LEN`` sampled RGB frames from the video at ``video_path``.

    PyAV is tried first. If it raises or cannot reach every sampled index,
    every index is re-read with OpenCV from scratch.

    Returns:
        tuple: ``(frames, indices)``. ``frames`` is a stacked array with one
        RGB frame (height, width, 3) per index. Returns ``(None, None)`` if the
        video cannot be opened, reports no frames, or the expected number of
        frames cannot be decoded.
    """
    try:
        container = av.open(video_path)
    except Exception as e:
        logger.error(f"Error opening video: {e}")
        return None, None

    frames = None
    try:
        try:
            seg_len = _count_frames(video_path, container.streams.video[0])
        except Exception as e:
            logger.error(f"Error opening video: {e}")
            return None, None
        if seg_len < 1:
            logger.error(f"Could not determine frame count (got {seg_len})")
            return None, None

        indices = sample_frame_indices(clip_len=CLIP_LEN, seg_len=seg_len)

        # Try PyAV decode
        try:
            frames = _decode_with_pyav(container, indices)
        except Exception as e:
            logger.warning(f"PyAV decoding failed, falling back to OpenCV: {e}")
    finally:
        container.close()

    # Fallback to OpenCV if necessary. The frame count may be only an
    # estimate, so PyAV can run out of frames before the last index.
    if frames is None:
        frames = _decode_with_opencv(video_path, indices)

    if len(frames) != len(indices):
        logger.error(f"Expected {len(indices)} frames, got {len(frames)}")
        return None, None
    return np.stack(frames), indices


def predict_video(frames):
    """Run inference on a clip of RGB frames and return the predicted label.

    Args:
        frames: Sequence of RGB uint8 frames, as returned by ``process_video``.

    Returns:
        str: ``model.config.id2label`` for the arg-max class, or "Unknown".
    """
    net, proc, tfm = load_model()
    # ToTensor already scales pixels to [0, 1], so the processor must not
    # rescale them a second time.
    clip = [tfm(Image.fromarray(f)) for f in frames]
    inputs = proc(clip, return_tensors="pt", do_rescale=False)
    inputs = {k: v.to(device) for k, v in inputs.items()}
    with torch.no_grad():
        logits = net(**inputs).logits
    pred_id = logits.argmax(-1).item()
    return net.config.id2label.get(pred_id, "Unknown")


@app.route('/classify-video', methods=['POST'])
def classify_video():
    """Classify the video uploaded in the multipart field ``video``."""
    if 'video' not in request.files:
        return jsonify({'error': 'No video file provided'}), 400
    file = request.files['video']
    if not file.filename:
        return jsonify({'error': 'Empty filename'}), 400

    temp_dir = tempfile.mkdtemp()
    # Never join the client-supplied filename into a path: "../x" or an
    # absolute name would escape temp_dir. Keep only a plain alphanumeric
    # extension, which can help the demuxer recognise the container format.
    suffix = os.path.splitext(os.path.basename(file.filename))[1]
    if not suffix[1:].isalnum():
        suffix = ""
    path = os.path.join(temp_dir, "upload" + suffix)
    try:
        file.save(path)
        frames, _ = process_video(path)
        if frames is None:
            return jsonify({'error': 'Failed to extract frames'}), 400
        prediction = predict_video(frames)
        return jsonify({'prediction': prediction})
    except Exception as e:
        logger.exception(f"Error during processing: {e}")
        return jsonify({'error': str(e)}), 500
    finally:
        shutil.rmtree(temp_dir, ignore_errors=True)


@app.route('/health', methods=['GET'])
def health_check():
    """Liveness probe."""
    return jsonify({'status': 'healthy'}), 200


if __name__ == '__main__':
    # Preload model on startup
    logger.info("Starting application and loading model...")
    load_model()
    port = int(os.environ.get('PORT', 7860))
    app.run(host='0.0.0.0', port=port, debug=False)