rajendrakumarv and tejovk311 committed
Commit c224e46 · verified · 1 parent: c03f135

Update app.py (#7)


- Update app.py (bfad6f40839370b37f6622329044c07d2100e11a)


Co-authored-by: Kattamuri Tejo Vardhan <tejovk311@users.noreply.huggingface.co>

Files changed (1)
  1. app.py +55 -68
app.py CHANGED
@@ -1,5 +1,10 @@
-from flask import Flask, request, jsonify
 import os
+# Configure Hugging Face caches to use the writable /cache volume in Spaces
+os.environ["HF_HOME"] = "/cache"
+os.environ["TRANSFORMERS_CACHE"] = "/cache"
+os.environ["HF_DATASETS_CACHE"] = "/cache"
+
+from flask import Flask, request, jsonify
 import numpy as np
 import torch
 import av
@@ -10,36 +15,40 @@ import logging
 from transformers import VideoMAEForVideoClassification, VideoMAEImageProcessor
 from PIL import Image
 from torchvision.transforms import Compose, Resize, ToTensor
-os.makedirs("./.cache", exist_ok=True)
 
+# Initialize Flask app
 app = Flask(__name__)
+
 # Configure logging
 logging.basicConfig(level=logging.INFO)
 logger = logging.getLogger(__name__)
 
-# Global variables to store model and processor
+# Globals for model, processor, and transforms
 device = "cuda" if torch.cuda.is_available() else "cpu"
 model = None
 processor = None
 transform = None
 
+
 def load_model():
-    """Load the model and processor"""
+    """Load the model and processor into globals"""
     global model, processor, transform
     if model is None:
         model_name = "OPear/videomae-large-finetuned-UCF-Crime"
-        logger.info(f"Loading model {model_name} on {device}...")
-        model = VideoMAEForVideoClassification.from_pretrained("OPear/videomae-large-finetuned-UCF-Crime",cache_dir="./.cache").to(device)
+        logger.info(f"Loading model {model_name} on device {device}")
+        # Downloads will go to /cache automatically
+        model = VideoMAEForVideoClassification.from_pretrained(model_name).to(device)
         processor = VideoMAEImageProcessor.from_pretrained(model_name)
         transform = Compose([
             Resize((224, 224)),
             ToTensor(),
         ])
-        logger.info("Model loaded successfully")
+        logger.info("Model and processor loaded successfully")
     return model, processor, transform
 
+
 def sample_frame_indices(clip_len=16, frame_sample_rate=1, seg_len=0):
-    """Samples exactly 16 frames uniformly from the video."""
+    """Uniformly sample exactly 16 frame indices from a clip"""
     if seg_len <= clip_len:
         indices = np.linspace(0, seg_len - 1, num=clip_len, dtype=int)
     else:
@@ -48,18 +57,23 @@ def sample_frame_indices(clip_len=16, frame_sample_rate=1, seg_len=0):
         indices = np.linspace(start_idx, end_idx - 1, num=clip_len, dtype=int)
     return np.clip(indices, 0, seg_len - 1)
 
+
 def process_video(video_path):
+    """Extract 16 uniformly-sampled frames from the video"""
     try:
         container = av.open(video_path)
         video_stream = container.streams.video[0]
-        seg_len = video_stream.frames if video_stream.frames > 0 else int(cv2.VideoCapture(video_path).get(cv2.CAP_PROP_FRAME_COUNT))
+        seg_len = video_stream.frames if video_stream.frames > 0 else int(
+            cv2.VideoCapture(video_path).get(cv2.CAP_PROP_FRAME_COUNT)
+        )
     except Exception as e:
-        logger.error(f"Error opening video: {str(e)}")
+        logger.error(f"Error opening video: {e}")
         return None, None
-
+
     indices = sample_frame_indices(clip_len=16, seg_len=seg_len)
     frames = []
 
+    # Try PyAV decode
     try:
         container.seek(0)
         for i, frame in enumerate(container.decode(video=0)):
@@ -68,101 +82,74 @@ def process_video(video_path):
             if i in indices:
                 frames.append(frame.to_ndarray(format="rgb24"))
     except Exception as e:
-        logger.error(f"Error decoding video with PyAV: {str(e)}")
+        logger.warning(f"PyAV decoding failed, falling back to OpenCV: {e}")
 
-    if not frames:
-        logger.info("Falling back to OpenCV for frame extraction")
+    # Fallback to OpenCV if necessary
+    if len(frames) < len(indices):
         cap = cv2.VideoCapture(video_path)
         for i in indices:
             cap.set(cv2.CAP_PROP_POS_FRAMES, i)
             ret, frame = cap.read()
             if ret:
-                frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
-                frames.append(frame)
+                frames.append(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
         cap.release()
 
     if len(frames) != 16:
-        logger.error(f"Could not extract 16 frames, got {len(frames)}")
+        logger.error(f"Expected 16 frames, got {len(frames)}")
         return None, None
 
     return np.stack(frames), indices
 
+
 def predict_video(frames):
-    """Processes frames and runs VideoMAE classification."""
+    """Run inference on a stack of 16 frames"""
    model, processor, transform = load_model()
-
-    video_tensor = torch.stack([transform(Image.fromarray(frame)) for frame in frames])
-    video_tensor = video_tensor.unsqueeze(0)  # Add batch dimension
+    video_tensor = torch.stack([transform(Image.fromarray(f)) for f in frames])
+    video_tensor = video_tensor.unsqueeze(0)
 
     inputs = processor(list(video_tensor[0]), return_tensors="pt", do_rescale=False)
     inputs = {k: v.to(device) for k, v in inputs.items()}
 
-    with torch.no_grad():  # Disable gradient calculation for inference
+    with torch.no_grad():
         outputs = model(**inputs)
-
     logits = outputs.logits
-    predicted_class = logits.argmax(-1).item()
+    pred_id = logits.argmax(-1).item()
+    return model.config.id2label.get(pred_id, "Unknown")
 
-    id2label = model.config.id2label
-    return id2label.get(predicted_class, "Unknown")
 
 @app.route('/classify-video', methods=['POST'])
 def classify_video():
     if 'video' not in request.files:
-        logger.warning("No video file in request")
         return jsonify({'error': 'No video file provided'}), 400
-
-    video_file = request.files['video']
-
-    if video_file.filename == '':
-        logger.warning("Empty video filename")
-        return jsonify({'error': 'No video selected'}), 400
-
-    # Create temporary directory
+
+    file = request.files['video']
+    if file.filename == '':
+        return jsonify({'error': 'Empty filename'}), 400
+
     temp_dir = tempfile.mkdtemp()
-    video_path = os.path.join(temp_dir, video_file.filename)
-
+    path = os.path.join(temp_dir, file.filename)
     try:
-        # Save the uploaded video
-        logger.info(f"Saving uploaded video to {video_path}")
-        video_file.save(video_path)
-
-        # Process the video
-        logger.info("Processing video...")
-        frames, indices = process_video(video_path)
-
+        file.save(path)
+        frames, _ = process_video(path)
         if frames is None:
-            return jsonify({'error': 'Failed to process video file'}), 400
-
-        # Get the prediction
-        logger.info("Running prediction...")
+            return jsonify({'error': 'Failed to extract frames'}), 400
         prediction = predict_video(frames)
-
-        logger.info(f"Prediction result: {prediction}")
         return jsonify({'prediction': prediction})
-
     except Exception as e:
-        logger.exception(f"Error processing video: {str(e)}")
-        return jsonify({'error': f'Error processing video: {str(e)}'}), 500
-
+        logger.exception(f"Error during processing: {e}")
+        return jsonify({'error': str(e)}), 500
     finally:
-        # Clean up the temporary directory and its contents
-        if os.path.exists(temp_dir):
-            logger.info(f"Cleaning up temporary directory: {temp_dir}")
-            shutil.rmtree(temp_dir)
+        shutil.rmtree(temp_dir, ignore_errors=True)
+
 
 @app.route('/health', methods=['GET'])
 def health_check():
-    """Endpoint to check if the service is up and running"""
-    return jsonify({"status": "healthy"}), 200
+    return jsonify({'status': 'healthy'}), 200
+
 
 if __name__ == '__main__':
-    # Load model at startup
-    logger.info("Initializing application...")
+    # Preload model on startup
+    logger.info("Starting application and loading model...")
     load_model()
-
-    # Get port from environment variable or use 5000 as default
     port = int(os.environ.get('PORT', 7860))
-
-    logger.info(f"Starting Flask application on port {port}")
-    app.run(host='0.0.0.0', port=port, debug=False)
+    app.run(host='0.0.0.0', port=port, debug=False)
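
For quick verification after this change, a minimal client sketch against the two routes defined in app.py follows. The base URL and the sample.mp4 path are placeholder assumptions (substitute your Space's URL and a real clip); the multipart field name "video" matches request.files['video'] in the handler, and the prediction string comes from the model's id2label mapping.

# Minimal client sketch; BASE_URL and the video path are placeholders.
import requests

BASE_URL = "http://localhost:7860"  # assumed local run; replace with the Space URL

# Liveness probe against /health
health = requests.get(f"{BASE_URL}/health", timeout=10)
print(health.status_code, health.json())  # expect: 200 {'status': 'healthy'}

# Upload a clip for classification; the multipart field name must be "video"
with open("sample.mp4", "rb") as f:  # placeholder path
    resp = requests.post(f"{BASE_URL}/classify-video", files={"video": f}, timeout=120)
print(resp.status_code, resp.json())
# success -> {'prediction': '<label from model.config.id2label>'}
# failure -> {'error': '...'} with status 400 or 500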