rajendrakumarv's picture
Update app.py (#6)
c03f135 verified
raw
history blame
5.89 kB
from flask import Flask, request, jsonify
import os
import numpy as np
import torch
import av
import cv2
import tempfile
import shutil
import logging
from transformers import VideoMAEForVideoClassification, VideoMAEImageProcessor
from PIL import Image
from torchvision.transforms import Compose, Resize, ToTensor
# Local directory used as the Hugging Face download cache (cache_dir in load_model).
os.makedirs("./.cache", exist_ok=True)

app = Flask(__name__)

# Root logger at INFO so the request-level messages below are emitted.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Run inference on CUDA when torch reports it available, otherwise on CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Populated lazily by load_model(); None until the first successful load.
model = processor = transform = None
def load_model():
    """Load the VideoMAE classifier, its image processor and the frame transform.

    The objects are created on the first call and stored in the module-level
    globals ``model``, ``processor`` and ``transform``; later calls return the
    cached objects without reloading.

    Returns:
        tuple: ``(model, processor, transform)``.
    """
    global model, processor, transform
    if model is None:
        model_name = "OPear/videomae-large-finetuned-UCF-Crime"
        logger.info(f"Loading model {model_name} on {device}...")
        # Build everything into locals first and publish the globals together,
        # so a failure part-way (e.g. while fetching the processor) does not
        # leave ``model`` set while ``processor``/``transform`` are still None.
        new_model = VideoMAEForVideoClassification.from_pretrained(
            model_name, cache_dir="./.cache"
        ).to(device)
        new_model.eval()  # inference only
        # Same repo and same cache directory as the model weights.
        new_processor = VideoMAEImageProcessor.from_pretrained(
            model_name, cache_dir="./.cache"
        )
        # Resize every frame to 224x224 and convert it to a float tensor.
        new_transform = Compose([
            Resize((224, 224)),
            ToTensor(),
        ])
        model, processor, transform = new_model, new_processor, new_transform
        logger.info("Model loaded successfully")
    return model, processor, transform
def sample_frame_indices(clip_len=16, frame_sample_rate=1, seg_len=0):
    """Choose ``clip_len`` frame indices from a video of ``seg_len`` frames.

    The sampling window is ``clip_len * frame_sample_rate`` frames long (never
    shorter than ``clip_len``).

    - If the whole video fits in the window, the indices are spread evenly
      over the entire video. When ``seg_len < clip_len`` some indices repeat.
    - Otherwise the window is placed at a random position, and the indices are
      spread evenly across it. The result is therefore non-deterministic for
      long videos.

    Args:
        clip_len: Number of indices to return.
        frame_sample_rate: Window length in units of ``clip_len`` frames. With
            the default of 1, the window is ``clip_len`` consecutive frames.
        seg_len: Total number of frames in the video; must be positive.

    Returns:
        np.ndarray of ``clip_len`` non-decreasing int indices, all within
        ``[0, seg_len - 1]``.

    Raises:
        ValueError: If ``seg_len`` is not positive.
    """
    if seg_len <= 0:
        # Previously this silently produced an array of -1 indices.
        raise ValueError(f"seg_len must be positive, got {seg_len}")

    window = max(clip_len, int(clip_len * frame_sample_rate))
    if seg_len <= window:
        indices = np.linspace(0, seg_len - 1, num=clip_len, dtype=int)
    else:
        # The window ends at end_idx - 1; end_idx is drawn from [window, seg_len).
        end_idx = np.random.randint(window, seg_len)
        start_idx = end_idx - window
        indices = np.linspace(start_idx, end_idx - 1, num=clip_len, dtype=int)
    return np.clip(indices, 0, seg_len - 1)
def _count_frames(container, video_path):
    """Return the frame count of the first video stream in ``container``.

    Uses the stream metadata when it is positive, otherwise asks OpenCV. The
    value may be <= 0 when neither source reports a count.
    """
    stream = container.streams.video[0]
    if stream.frames > 0:
        return stream.frames
    cap = cv2.VideoCapture(video_path)
    try:
        return int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
    finally:
        cap.release()


def _decode_with_pyav(container, indices):
    """Decode ``container`` sequentially and collect RGB frames at ``indices``.

    ``indices`` must be sorted ascending (sample_frame_indices returns them that
    way). An index that appears k times yields its frame k times, so the result
    lines up one-to-one with ``indices``.
    """
    wanted = {
        int(k): int(c) for k, c in zip(*np.unique(indices, return_counts=True))
    }
    last = int(indices[-1])
    frames = []
    container.seek(0)
    for i, frame in enumerate(container.decode(video=0)):
        if i > last:
            break
        repeat = wanted.get(i, 0)
        if repeat:
            rgb = frame.to_ndarray(format="rgb24")
            frames.extend([rgb] * repeat)
    return frames


def _decode_with_opencv(video_path, indices):
    """Seek to each of ``indices`` with OpenCV and return the frames as RGB.

    Frames that cannot be read are skipped, so the result may be shorter than
    ``indices``.
    """
    cap = cv2.VideoCapture(video_path)
    frames = []
    try:
        for i in indices:
            cap.set(cv2.CAP_PROP_POS_FRAMES, int(i))
            ok, bgr = cap.read()
            if ok:
                frames.append(cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB))
    finally:
        cap.release()
    return frames


def process_video(video_path, clip_len=16):
    """Extract ``clip_len`` RGB frames from the video file at ``video_path``.

    Frame positions come from sample_frame_indices. Frames are decoded with
    PyAV first. If that raises or yields the wrong number of frames, OpenCV is
    used instead.

    Args:
        video_path: Path to a video file on disk.
        clip_len: Number of frames to extract (default 16).

    Returns:
        On success, ``(frames, indices)``:
            - ``frames`` is ``np.stack`` of the RGB frames, one per index.
            - ``indices`` is the array of sampled frame numbers.
        On failure (the video cannot be opened, its length is unknown, or too
        few frames are decoded), ``(None, None)``.
    """
    try:
        container = av.open(video_path)
    except Exception as e:
        logger.error(f"Error opening video: {str(e)}")
        return None, None

    frames = []
    try:
        try:
            seg_len = _count_frames(container, video_path)
        except Exception as e:
            logger.error(f"Error opening video: {str(e)}")
            return None, None
        if seg_len <= 0:
            logger.error(f"Could not determine frame count for {video_path}")
            return None, None

        indices = sample_frame_indices(clip_len=clip_len, seg_len=seg_len)
        try:
            frames = _decode_with_pyav(container, indices)
        except Exception as e:
            logger.error(f"Error decoding video with PyAV: {str(e)}")
            frames = []
    finally:
        # Release the demuxer/decoder whatever happened above.
        container.close()

    # Fall back on an incomplete decode too, not only an empty one.
    if len(frames) != clip_len:
        logger.info("Falling back to OpenCV for frame extraction")
        frames = _decode_with_opencv(video_path, indices)

    if len(frames) != clip_len:
        logger.error(f"Could not extract {clip_len} frames, got {len(frames)}")
        return None, None
    return np.stack(frames), indices
def predict_video(frames):
    """Run VideoMAE classification on a clip of frames and return its label.

    Args:
        frames: Sequence of RGB frame arrays (as returned by process_video);
            each must be accepted by ``PIL.Image.fromarray``.

    Returns:
        str: ``model.config.id2label`` entry for the arg-max logit, or
        ``"Unknown"`` if that id has no label.
    """
    net, image_processor, frame_transform = load_model()

    stacked = torch.stack([frame_transform(Image.fromarray(f)) for f in frames])
    # The transform has already scaled pixel values, so the processor must not
    # rescale them again.
    encoded = image_processor(list(stacked), return_tensors="pt", do_rescale=False)
    model_inputs = {name: tensor.to(device) for name, tensor in encoded.items()}

    with torch.no_grad():  # inference only, no autograd bookkeeping
        scores = net(**model_inputs).logits

    label_id = scores.argmax(-1).item()
    return net.config.id2label.get(label_id, "Unknown")
@app.route('/classify-video', methods=['POST'])
def classify_video():
    """Classify a video uploaded as the multipart form field ``video``.

    Returns:
        - 200 with ``{"prediction": <label>}`` on success.
        - 400 with ``{"error": ...}`` when no file is supplied or no frames can
          be extracted.
        - 500 with ``{"error": ...}`` on any other failure.
    """
    if 'video' not in request.files:
        logger.warning("No video file in request")
        return jsonify({'error': 'No video file provided'}), 400
    video_file = request.files['video']
    # ``filename`` may be '' or None; treat both as "nothing selected".
    if not video_file.filename:
        logger.warning("Empty video filename")
        return jsonify({'error': 'No video selected'}), 400

    # Create temporary directory
    temp_dir = tempfile.mkdtemp()
    # Never use the client-supplied name as a path component. Names such as
    # "../../x" would escape temp_dir, and os.path.join drops temp_dir
    # entirely for absolute names. Keep only a sanitized extension, which
    # helps the decoders recognise the container format.
    ext = os.path.splitext(os.path.basename(video_file.filename))[1]
    ext = "".join(c for c in ext if c.isascii() and c.isalnum())[:16]
    video_path = os.path.join(temp_dir, f"upload.{ext}" if ext else "upload")
    try:
        # Save the uploaded video
        logger.info(f"Saving uploaded video to {video_path}")
        video_file.save(video_path)
        # Process the video
        logger.info("Processing video...")
        frames, _ = process_video(video_path)
        if frames is None:
            return jsonify({'error': 'Failed to process video file'}), 400
        # Get the prediction
        logger.info("Running prediction...")
        prediction = predict_video(frames)
        logger.info(f"Prediction result: {prediction}")
        return jsonify({'prediction': prediction})
    except Exception as e:
        logger.exception(f"Error processing video: {str(e)}")
        return jsonify({'error': f'Error processing video: {str(e)}'}), 500
    finally:
        # Clean up the temporary directory and its contents. Ignore cleanup
        # errors so they cannot replace the response computed above.
        if os.path.exists(temp_dir):
            logger.info(f"Cleaning up temporary directory: {temp_dir}")
            shutil.rmtree(temp_dir, ignore_errors=True)
@app.route('/health', methods=['GET'])
def health_check():
    """Liveness probe: answer 200 with a fixed status payload."""
    payload = {"status": "healthy"}
    status_code = 200
    return jsonify(payload), status_code
if __name__ == '__main__':
    # Load the model eagerly: load_model() caches it, so the first request does
    # not pay the download/load cost.
    logger.info("Initializing application...")
    load_model()
    # Port comes from the PORT environment variable, defaulting to 7860.
    port = int(os.environ.get('PORT', 7860))
    logger.info(f"Starting Flask application on port {port}")
    app.run(host='0.0.0.0', port=port, debug=False)