File size: 5,796 Bytes
3ba492a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
from flask import Flask, request, jsonify
import os
import numpy as np
import torch
import av
import cv2
import tempfile
import shutil
import logging
from transformers import VideoMAEForVideoClassification, VideoMAEImageProcessor
from PIL import Image
from torchvision.transforms import Compose, Resize, ToTensor

# Flask application object; the route decorators below register handlers on it.
app = Flask(__name__)
# Configure logging: INFO level on the root logger, plus a module-level logger.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Global state shared by all requests.
# `device` is "cuda" when torch reports an available GPU, otherwise "cpu".
device = "cuda" if torch.cuda.is_available() else "cpu"
# The three names below stay None until load_model() populates them.
model = None
processor = None
transform = None

def load_model():
    """Return the shared ``(model, processor, transform)`` triple.

    The first call populates the module-level globals: the VideoMAE
    classification model (moved to ``device``), its image processor, and a
    torchvision transform that resizes to 224x224 and converts to a tensor.
    Later calls return the already-loaded objects without reloading.
    """
    global model, processor, transform

    # Already initialised: hand back the cached objects.
    if model is not None:
        return model, processor, transform

    model_name = "OPear/videomae-large-finetuned-UCF-Crime"
    logger.info(f"Loading model {model_name} on {device}...")

    model = VideoMAEForVideoClassification.from_pretrained(model_name).to(device)
    processor = VideoMAEImageProcessor.from_pretrained(model_name)
    transform = Compose([Resize((224, 224)), ToTensor()])

    logger.info("Model loaded successfully")
    return model, processor, transform

def sample_frame_indices(clip_len=16, frame_sample_rate=1, seg_len=0):
    """Choose ``clip_len`` frame indices from a video of ``seg_len`` frames.

    A window of ``clip_len * frame_sample_rate`` frames is placed at a random
    position inside the video and ``clip_len`` evenly spaced indices are taken
    from it (with the default rate of 1 this is a run of consecutive frames).
    When the video is not longer than that window, the indices are spread
    over the whole video instead, so short videos yield repeated indices.

    Args:
        clip_len: Number of indices to return.
        frame_sample_rate: Approximate spacing between sampled frames.
            Values below 1 are treated as 1.
        seg_len: Total number of frames in the video.

    Returns:
        An integer NumPy array of ``clip_len`` indices in ``[0, seg_len - 1]``,
        or an empty array when ``seg_len`` is not positive (there is no valid
        frame index to return in that case).
    """
    if seg_len <= 0:
        # Without this guard np.clip(..., 0, -1) would yield all -1 indices.
        return np.empty(0, dtype=int)

    # Previously ignored; a rate of 1 reproduces the historical behaviour.
    stride = max(1, int(frame_sample_rate))
    window = clip_len * stride

    if seg_len <= window:
        indices = np.linspace(0, seg_len - 1, num=clip_len, dtype=int)
    else:
        end_idx = np.random.randint(window, seg_len)
        start_idx = end_idx - window
        indices = np.linspace(start_idx, end_idx - 1, num=clip_len, dtype=int)
    return np.clip(indices, 0, seg_len - 1)

def _count_video_frames(container, video_path):
    """Return the frame count PyAV reports, or OpenCV's when PyAV reports none."""
    stream = container.streams.video[0]
    if stream.frames > 0:
        return stream.frames
    probe = cv2.VideoCapture(video_path)
    try:
        return int(probe.get(cv2.CAP_PROP_FRAME_COUNT))
    finally:
        probe.release()


def _decode_frames_pyav(container, wanted):
    """Decode the frames whose index is in ``wanted`` with PyAV as {index: RGB array}."""
    decoded = {}
    if not wanted:
        return decoded
    last_wanted = max(wanted)
    container.seek(0)
    for i, frame in enumerate(container.decode(video=0)):
        if i > last_wanted:
            break
        if i in wanted:
            decoded[i] = frame.to_ndarray(format="rgb24")
    return decoded


def _decode_frames_opencv(video_path, wanted):
    """Read the frames whose index is in ``wanted`` with OpenCV as {index: RGB array}."""
    decoded = {}
    cap = cv2.VideoCapture(video_path)
    try:
        for i in sorted(wanted):
            cap.set(cv2.CAP_PROP_POS_FRAMES, i)
            ret, frame = cap.read()
            if ret:
                decoded[i] = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
    finally:
        cap.release()
    return decoded


def process_video(video_path):
    """Extract a 16-frame RGB clip from the video at ``video_path``.

    Frame indices come from ``sample_frame_indices``. Frames are decoded with
    PyAV first; if PyAV fails or does not deliver every requested frame, the
    frames are read again with OpenCV.

    Args:
        video_path: Path of the video file on disk.

    Returns:
        ``(frames, indices)`` where ``frames`` is the stacked array of the 16
        sampled frames and ``indices`` the sampled frame indices, or
        ``(None, None)`` when the video cannot be opened or 16 frames cannot
        be extracted.
    """
    clip_len = 16

    try:
        container = av.open(video_path)
    except Exception as e:
        logger.error(f"Error opening video: {str(e)}")
        return None, None

    # The container is closed on every path out of this block.
    try:
        try:
            seg_len = _count_video_frames(container, video_path)
        except Exception as e:
            logger.error(f"Error opening video: {str(e)}")
            return None, None

        if seg_len <= 0:
            logger.error("Could not determine the number of frames in the video")
            return None, None

        indices = sample_frame_indices(clip_len=clip_len, seg_len=seg_len)
        # Short videos produce repeated indices; decode each distinct frame once.
        wanted = {int(i) for i in indices}

        try:
            decoded = _decode_frames_pyav(container, wanted)
        except Exception as e:
            logger.error(f"Error decoding video with PyAV: {str(e)}")
            decoded = {}
    finally:
        container.close()

    # Fall back whenever PyAV missed any requested frame, not only when it got none.
    if not wanted.issubset(decoded):
        logger.info("Falling back to OpenCV for frame extraction")
        decoded = _decode_frames_opencv(video_path, wanted)

    # Rebuild the clip in index order so repeated indices repeat their frame.
    frames = [decoded[int(i)] for i in indices if int(i) in decoded]

    if len(frames) != clip_len:
        logger.error(f"Could not extract {clip_len} frames, got {len(frames)}")
        return None, None

    return np.stack(frames), indices

def predict_video(frames):
    """Classify a clip of RGB frames with VideoMAE and return the label name.

    Each frame is converted to a PIL image and passed through the shared
    transform, then the image processor builds the model inputs (rescaling is
    disabled because the transform output is passed as-is). The label for the
    highest-scoring logit is looked up in the model config; "Unknown" is
    returned when the id has no entry there.
    """
    model, processor, transform = load_model()

    per_frame = [transform(Image.fromarray(frame)) for frame in frames]
    clip = torch.stack(per_frame)

    encoded = processor(list(clip), return_tensors="pt", do_rescale=False)
    model_inputs = {name: tensor.to(device) for name, tensor in encoded.items()}

    # Inference only: no gradients needed.
    with torch.no_grad():
        result = model(**model_inputs)

    best_class = result.logits.argmax(-1).item()
    return model.config.id2label.get(best_class, "Unknown")

@app.route('/classify-video', methods=['POST'])
def classify_video():
    """Classify an uploaded video and return the predicted label as JSON.

    Expects a multipart POST with the file in the ``video`` field. The upload
    is saved into a fresh temporary directory, a clip is extracted with
    ``process_video`` and classified with ``predict_video``.

    Returns:
        ``{'prediction': <label>}`` on success; ``{'error': ...}`` with status
        400 when the file is missing, unnamed, or cannot be processed; and
        ``{'error': ...}`` with status 500 on an unexpected exception.
    """
    if 'video' not in request.files:
        logger.warning("No video file in request")
        return jsonify({'error': 'No video file provided'}), 400

    video_file = request.files['video']

    # Covers both an empty and a missing (None) filename.
    if not video_file.filename:
        logger.warning("Empty video filename")
        return jsonify({'error': 'No video selected'}), 400

    # The filename is client-controlled: joining it into a path would allow
    # "../" traversal or an absolute path to escape the temp directory. Keep
    # only a plain alphanumeric extension and use a fixed base name instead.
    extension = os.path.splitext(os.path.basename(video_file.filename))[1]
    if not extension[1:].isalnum():
        extension = ''

    # Create temporary directory
    temp_dir = tempfile.mkdtemp()
    video_path = os.path.join(temp_dir, f"upload{extension}")

    try:
        # Save the uploaded video
        logger.info(f"Saving uploaded video to {video_path}")
        video_file.save(video_path)

        # Process the video
        logger.info("Processing video...")
        frames, indices = process_video(video_path)

        if frames is None:
            return jsonify({'error': 'Failed to process video file'}), 400

        # Get the prediction
        logger.info("Running prediction...")
        prediction = predict_video(frames)

        logger.info(f"Prediction result: {prediction}")
        return jsonify({'prediction': prediction})

    except Exception as e:
        logger.exception(f"Error processing video: {str(e)}")
        return jsonify({'error': f'Error processing video: {str(e)}'}), 500

    finally:
        # Clean up the temporary directory and its contents. Errors are
        # ignored so a cleanup failure cannot replace the response above.
        if os.path.exists(temp_dir):
            logger.info(f"Cleaning up temporary directory: {temp_dir}")
            shutil.rmtree(temp_dir, ignore_errors=True)

@app.route('/health', methods=['GET'])
def health_check():
    """Liveness probe: always answers 200 with a fixed "healthy" status body."""
    payload = {"status": "healthy"}
    return jsonify(payload), 200

if __name__ == '__main__':
    # Load the model once up front, before the server starts.
    logger.info("Initializing application...")
    load_model()

    # Listen on the PORT environment variable when set; otherwise use 7860.
    listen_port = int(os.getenv('PORT', 7860))

    logger.info(f"Starting Flask application on port {listen_port}")
    app.run(host='0.0.0.0', port=listen_port, debug=False)