MNGames committed on
Commit 29cb7aa
1 Parent(s): e7b04ff

Update app.py

Files changed (1)
app.py  +54 -62
app.py CHANGED
@@ -1,68 +1,77 @@

Old version (lines marked "-" were removed):

  import gradio as gr
- from transformers import VideoMAEForVideoClassification, VideoMAEImageProcessor
  import torch
- import cv2  # OpenCV for video processing
-
- # Model ID for video classification (UCF101 subset)
- model_id = "MCG-NJU/videomae-base"
-
- # Parameters for frame extraction
- TARGET_FRAME_COUNT = 16
- FRAME_SIZE = (224, 224)  # Expected frame size for the model

  def analyze_video(video):
      # Extract key frames from the video using OpenCV
-     frames = extract_key_frames(video, TARGET_FRAME_COUNT)
-
-     # Resize frames to the expected size
-     frames = [cv2.resize(frame, FRAME_SIZE) for frame in frames]

-     # Load model and feature extractor manually
-     model = VideoMAEForVideoClassification.from_pretrained(model_id)
-     processor = VideoMAEImageProcessor.from_pretrained(model_id)

-     # Prepare frames for the model
      inputs = processor(images=frames, return_tensors="pt")

-     # Make predictions
      with torch.no_grad():
-         outputs = model(**inputs)

      logits = outputs.logits
      predictions = torch.argmax(logits, dim=-1)

-     # Analyze predictions for insights related to the play
-     results = []
-     for prediction in predictions:
-         result = analyze_predictions_ucf101(prediction.item())
-         results.append(result)

-     # Aggregate results across frames and provide a final analysis
-     final_result = aggregate_results(results)

      return final_result

- def extract_key_frames(video, target_frame_count):
      cap = cv2.VideoCapture(video)
      frames = []
      frame_count = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
-
-     # Calculate interval for frame extraction
-     interval = max(1, frame_count // target_frame_count)
-
-     for i in range(0, frame_count, interval):
-         cap.set(cv2.CAP_PROP_POS_FRAMES, i)
          ret, frame = cap.read()
-         if ret:
-             frames.append(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))  # Convert to RGB
-             if len(frames) >= target_frame_count:
-                 break
-
      cap.release()
      return frames

- def analyze_predictions_ucf101(prediction):
-     # Map prediction to action labels (this mapping is hypothetical)
      action_labels = {
          0: "running",
          1: "sliding",
@@ -70,37 +79,20 @@ def analyze_predictions_ucf101(prediction):
          # Add more labels as necessary
      }
      action = action_labels.get(prediction, "unknown")
-
-     relevant_actions = ["running", "sliding", "jumping"]
-     if action in relevant_actions:
-         if action == "sliding":
-             return "potentially safe"
-         elif action == "running":
-             return "potentially out"
-         else:
-             return "inconclusive"
-     else:
-         return "inconclusive"

  def aggregate_results(results):
-     # Combine insights from analyzing each frame (e.g., dominant action classes, confidence scores)
-     safe_count = results.count("potentially safe")
-     out_count = results.count("potentially out")
-
-     if safe_count > out_count:
-         return "Safe"
-     elif out_count > safe_count:
-         return "Out"
-     else:
-         return "Inconclusive"

  # Gradio interface
  interface = gr.Interface(
      fn=analyze_video,
      inputs="video",
      outputs="text",
-     title="Baseball Play Analysis (UCF101 Subset Exploration)",
-     description="Upload a video of a baseball play (safe/out at a base). This app explores using a video classification model (UCF101 subset) for analysis. Note: The model might not be specifically trained for baseball plays."
  )

- interface.launch(share=True)
 
New version (lines marked "+" were added):

  import gradio as gr
  import torch
+ import cv2
+ import numpy as np
+ from transformers import VideoMAEForVideoClassification, VideoMAEImageProcessor

+ # Model IDs for video classification (UCF101 subset)
+ classification_model_id = "MCG-NJU/videomae-base"

+ # Object detection model (you can replace this with a more accurate one if needed)
+ object_detection_model = "yolov5s"

  def analyze_video(video):
      # Extract key frames from the video using OpenCV
+     frames = extract_key_frames(video)

+     # Load classification model and image processor
+     classification_model = VideoMAEForVideoClassification.from_pretrained(classification_model_id)
+     processor = VideoMAEImageProcessor.from_pretrained(classification_model_id)

+     # Prepare frames for the classification model
      inputs = processor(images=frames, return_tensors="pt")

+     # Make predictions using the classification model
      with torch.no_grad():
+         outputs = classification_model(**inputs)

      logits = outputs.logits
      predictions = torch.argmax(logits, dim=-1)

+     # Object detection and tracking (ball and baseman)
+     object_detection_results = []
+     for frame in frames:
+         ball_position = detect_object(frame, "ball")
+         baseman_position = detect_object(frame, "baseman")
+         object_detection_results.append((ball_position, baseman_position))
+
+     # Analyze predictions and object detection results
+     analysis_results = []
+     for prediction, (ball_position, baseman_position) in zip(predictions, object_detection_results):
+         result = analyze_frame(prediction.item(), ball_position, baseman_position)
+         analysis_results.append(result)

+     # Aggregate analysis results
+     final_result = aggregate_results(analysis_results)

      return final_result

+ def extract_key_frames(video):
      cap = cv2.VideoCapture(video)
      frames = []
      frame_count = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
+     fps = int(cap.get(cv2.CAP_PROP_FPS))
+     for i in range(frame_count):
          ret, frame = cap.read()
+         if ret and i % (fps // 2) == 0:  # Extract a frame every half second
+             frames.append(frame)
      cap.release()
      return frames

+ def detect_object(frame, object_type):
+     # Placeholder function for object detection (replace with actual implementation)
+     # Here, we assume that the object is detected at the center of the frame
+     h, w, _ = frame.shape
+     if object_type == "ball":
+         return (w // 2, h // 2)  # Return center coordinates for the ball
+     elif object_type == "baseman":
+         return (w // 2, h // 2)  # Return center coordinates for the baseman
+     else:
+         return None
+
+ def analyze_frame(prediction, ball_position, baseman_position):
+     # Placeholder function for analyzing a single frame
+     # You can replace this with actual logic based on your requirements
      action_labels = {
          0: "running",
          1: "sliding",
          # Add more labels as necessary
      }
      action = action_labels.get(prediction, "unknown")
+     return {"action": action, "ball_position": ball_position, "baseman_position": baseman_position}

  def aggregate_results(results):
+     # Placeholder function for aggregating analysis results
+     # You can implement this based on your specific requirements
+     return results

  # Gradio interface
  interface = gr.Interface(
      fn=analyze_video,
      inputs="video",
      outputs="text",
+     title="Baseball Play Analysis",
+     description="Upload a video of a baseball play to analyze.",
  )

+ interface.launch()
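
A note on the new extract_key_frames: i % (fps // 2) raises ZeroDivisionError whenever OpenCV reports an FPS below 2 (cap.get(cv2.CAP_PROP_FPS) can return 0 for some containers), the frames stay in OpenCV's BGR channel order even though the old version converted them to RGB, and the clip length is no longer pinned to the 16 frames VideoMAE checkpoints are normally trained on. A minimal, more defensive sketch; the signature and the NUM_FRAMES constant are assumptions, not part of this commit:

import cv2
import numpy as np

NUM_FRAMES = 16  # assumption: VideoMAE checkpoints typically expect 16-frame clips

def extract_key_frames(video_path, num_frames=NUM_FRAMES):
    """Sample num_frames evenly spaced RGB frames from a video file (hypothetical replacement)."""
    cap = cv2.VideoCapture(video_path)
    total = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
    frames = []
    if total <= 0:
        cap.release()
        return frames
    # Evenly spaced indices across the clip, independent of the reported FPS
    indices = np.linspace(0, total - 1, num=num_frames, dtype=int)
    for idx in indices:
        cap.set(cv2.CAP_PROP_POS_FRAMES, int(idx))
        ret, frame = cap.read()
        if not ret:
            continue
        # OpenCV decodes to BGR; Hugging Face image processors expect RGB
        frames.append(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
    cap.release()
    return frames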
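
MCG-NJU/videomae-base is the self-supervised backbone only, so VideoMAEForVideoClassification.from_pretrained attaches a freshly initialised classification head, and the hand-written {0: "running", 1: "sliding"} mapping does not correspond to anything the model has learned. One option, sketched under the assumption that a generic action vocabulary is acceptable, is to load a fine-tuned checkpoint such as MCG-NJU/videomae-base-finetuned-kinetics and read its own config.id2label:

import torch
from transformers import VideoMAEForVideoClassification, VideoMAEImageProcessor

# Assumption: the Kinetics-400 fine-tuned checkpoint is an acceptable substitute
ckpt = "MCG-NJU/videomae-base-finetuned-kinetics"
processor = VideoMAEImageProcessor.from_pretrained(ckpt)
model = VideoMAEForVideoClassification.from_pretrained(ckpt)

def classify_clip(frames_rgb):
    """frames_rgb: list of 16 RGB frames (H, W, 3). Returns the predicted label string."""
    inputs = processor(frames_rgb, return_tensors="pt")
    with torch.no_grad():
        logits = model(**inputs).logits
    return model.config.id2label[int(logits.argmax(-1))]

The Kinetics vocabulary covers generic actions rather than base-running outcomes, so any safe/out heuristic would still have to map from whatever labels the checkpoint returns.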
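
object_detection_model = "yolov5s" is declared but never loaded, and detect_object returns the frame centre for every query. If YOLOv5 was the intent, one way it could be wired in is through torch.hub; the class mapping below is an assumption on my part, since the COCO-pretrained model knows "sports ball" and "person" but has no "baseman" class:

import torch

# Load once at startup; torch.hub fetches the ultralytics/yolov5 repo on first use
yolo = torch.hub.load("ultralytics/yolov5", "yolov5s", pretrained=True)

# Assumed mapping: COCO has no "baseman", so "person" is only a rough stand-in
COCO_NAME_FOR = {"ball": "sports ball", "baseman": "person"}

def detect_object(frame_rgb, object_type):
    """Return the (x, y) centre of the highest-confidence matching box, or None."""
    target = COCO_NAME_FOR.get(object_type)
    if target is None:
        return None
    results = yolo(frame_rgb)
    best = None
    for *box, conf, cls in results.xyxy[0].tolist():  # rows are [x1, y1, x2, y2, conf, cls]
        if results.names[int(cls)] == target and (best is None or conf > best[1]):
            x1, y1, x2, y2 = box
            best = (((x1 + x2) / 2, (y1 + y2) / 2), conf)
    return best[0] if best else None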
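
As committed, analyze_frame only echoes its inputs and aggregate_results returns the raw list, so the Gradio text box shows a Python repr of dictionaries rather than a safe/out call. A sketch of one way to recover the previous version's verdict while using the new positional data; call_frame is a hypothetical helper (not the committed analyze_frame) and reach_px is a made-up pixel threshold standing in for a real tag/force-out rule:

from collections import Counter

def call_frame(action, ball_position, baseman_position, reach_px=80):
    """Rough per-frame call from the predicted action and detected positions (hypothetical)."""
    if ball_position is None or baseman_position is None:
        return "inconclusive"
    dx = ball_position[0] - baseman_position[0]
    dy = ball_position[1] - baseman_position[1]
    ball_is_close = (dx * dx + dy * dy) ** 0.5 < reach_px  # reach_px is an assumed threshold
    if action == "sliding" and not ball_is_close:
        return "potentially safe"
    if action == "running" and ball_is_close:
        return "potentially out"
    return "inconclusive"

def aggregate_results(frame_calls):
    """Majority vote over the per-frame calls, ignoring inconclusive frames (ties lean Safe)."""
    votes = Counter(c for c in frame_calls if c != "inconclusive")
    if not votes:
        return "Inconclusive"
    return "Safe" if votes["potentially safe"] >= votes["potentially out"] else "Out"

Returning plain strings keeps the vote trivial, at the cost of discarding confidence; weighting each call by the classifier's softmax probability would be a natural refinement.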