import logging
import os
import shutil
import tempfile

import av
import cv2
import numpy as np
import torch
from flask import Flask, request, jsonify
from PIL import Image
from torchvision.transforms import Compose, Resize, ToTensor
from transformers import VideoMAEForVideoClassification, VideoMAEImageProcessor
from werkzeug.utils import secure_filename

app = Flask(__name__)

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Global variables to store the model and processor (loaded lazily on first use)
device = "cuda" if torch.cuda.is_available() else "cpu"
model = None
processor = None
transform = None

# Number of frames VideoMAE classifies per clip
CLIP_LEN = 16


def load_model():
    """Load the model, processor and frame transform once and cache them in globals."""
    global model, processor, transform
    if model is None:
        model_name = "OPear/videomae-large-finetuned-UCF-Crime"
        logger.info(f"Loading model {model_name} on {device}...")
        model = VideoMAEForVideoClassification.from_pretrained(model_name).to(device)
        model.eval()
        processor = VideoMAEImageProcessor.from_pretrained(model_name)
        # Resize each frame to 224x224 and convert it to a (3, 224, 224) float tensor in [0, 1]
        transform = Compose([
            Resize((224, 224)),
            ToTensor(),
        ])
        logger.info("Model loaded successfully")
    return model, processor, transform


def sample_frame_indices(clip_len=CLIP_LEN, frame_sample_rate=1, seg_len=0):
    """Return `clip_len` frame indices for a video with `seg_len` frames.

    Videos with at most clip_len * frame_sample_rate frames are sampled uniformly
    over their whole length (repeating frames if the video is shorter than
    clip_len). Longer videos are sampled from a randomly positioned window of
    clip_len * frame_sample_rate frames, roughly every frame_sample_rate-th frame.
    """
    if seg_len <= 0:
        raise ValueError("Video has no frames")
    converted_len = clip_len * frame_sample_rate
    if seg_len <= converted_len:
        indices = np.linspace(0, seg_len - 1, num=clip_len, dtype=int)
    else:
        # end_idx can be seg_len, so the last frame of the video is reachable
        end_idx = np.random.randint(converted_len, seg_len + 1)
        start_idx = end_idx - converted_len
        indices = np.linspace(start_idx, end_idx - 1, num=clip_len, dtype=int)
    return np.clip(indices, 0, seg_len - 1)


def count_frames(video_path):
    """Return the frame count reported by the container, falling back to OpenCV."""
    with av.open(video_path) as container:
        frame_count = container.streams.video[0].frames
    if frame_count > 0:
        return frame_count
    cap = cv2.VideoCapture(video_path)
    frame_count = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
    cap.release()
    return frame_count


def process_video(video_path, clip_len=CLIP_LEN):
    """Decode `clip_len` RGB frames from the video.

    Returns (frames, indices), where frames has shape (clip_len, H, W, 3),
    or (None, None) if the video cannot be read.
    """
    try:
        seg_len = count_frames(video_path)
    except Exception as e:
        logger.error(f"Error opening video: {e}")
        return None, None
    if seg_len <= 0:
        logger.error("Could not determine the number of frames in the video")
        return None, None

    indices = sample_frame_indices(clip_len=clip_len, seg_len=seg_len)
    idx_list = indices.tolist()

    frames = []
    try:
        needed = set(idx_list)
        decoded = {}
        with av.open(video_path) as container:
            for i, frame in enumerate(container.decode(video=0)):
                if i > idx_list[-1]:
                    break
                if i in needed:
                    decoded[i] = frame.to_ndarray(format="rgb24")
        # Repeated indices (short videos) reuse the same decoded frame
        frames = [decoded[i] for i in idx_list if i in decoded]
    except Exception as e:
        logger.error(f"Error decoding video with PyAV: {e}")
        frames = []

    if len(frames) != clip_len:
        logger.info("Falling back to OpenCV for frame extraction")
        frames = []
        cap = cv2.VideoCapture(video_path)
        for i in idx_list:
            cap.set(cv2.CAP_PROP_POS_FRAMES, i)
            ret, frame = cap.read()
            if ret:
                frames.append(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
        cap.release()

    if len(frames) != clip_len:
        logger.error(f"Could not extract {clip_len} frames, got {len(frames)}")
        return None, None
    return np.stack(frames), indices


def predict_video(frames):
    """Run VideoMAE classification on a (CLIP_LEN, H, W, 3) RGB array and return the label."""
    model, processor, transform = load_model()
    # Resize every frame and convert it to a tensor scaled to [0, 1]
    clip = [transform(Image.fromarray(frame)) for frame in frames]
    # The tensors are already in [0, 1], so skip the processor's own rescaling;
    # it still applies its resize/crop and mean/std normalization
    inputs = processor(clip, return_tensors="pt", do_rescale=False)
    inputs = {k: v.to(device) for k, v in inputs.items()}

    with torch.no_grad():  # Disable gradient calculation for inference
        logits = model(**inputs).logits

    predicted_class = logits.argmax(-1).item()
    return model.config.id2label.get(predicted_class, "Unknown")


@app.route('/classify-video', methods=['POST'])
def classify_video():
    if 'video' not in request.files:
        logger.warning("No video file in request")
        return jsonify({'error': 'No video file provided'}), 400

    video_file = request.files['video']
    if video_file.filename == '':
        logger.warning("Empty video filename")
        return jsonify({'error': 'No video selected'}), 400

    # Create a temporary directory and sanitize the client-supplied filename so it
    # cannot point outside that directory (e.g. "../../etc/passwd")
    temp_dir = tempfile.mkdtemp()
    filename = secure_filename(video_file.filename) or "upload"
    video_path = os.path.join(temp_dir, filename)

    try:
        # Save the uploaded video
        logger.info(f"Saving uploaded video to {video_path}")
        video_file.save(video_path)

        # Extract the frames to classify
        logger.info("Processing video...")
        frames, _ = process_video(video_path)
        if frames is None:
            return jsonify({'error': 'Failed to process video file'}), 400

        # Get the prediction
        logger.info("Running prediction...")
        prediction = predict_video(frames)
        logger.info(f"Prediction result: {prediction}")

        return jsonify({'prediction': prediction})

    except Exception as e:
        logger.exception(f"Error processing video: {e}")
        return jsonify({'error': f'Error processing video: {e}'}), 500

    finally:
        # Clean up the temporary directory and its contents
        logger.info(f"Cleaning up temporary directory: {temp_dir}")
        shutil.rmtree(temp_dir, ignore_errors=True)


@app.route('/health', methods=['GET'])
def health_check():
    """Endpoint to check if the service is up and running"""
    return jsonify({"status": "healthy"}), 200


if __name__ == '__main__':
    # Load the model at startup so the first request does not pay the loading cost
    logger.info("Initializing application...")
    load_model()

    # Get the port from the PORT environment variable, defaulting to 7860
    port = int(os.environ.get('PORT', 7860))
    logger.info(f"Starting Flask application on port {port}")
    app.run(host='0.0.0.0', port=port, debug=False)