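"""Baseball play analysis demo.

Classifies an uploaded video with VideoMAE and pairs the video-level prediction
with placeholder per-frame object detection (ball and baseman positions).
Detection, per-frame analysis, and aggregation are stubs to be replaced with
real implementations.
"""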
import gradio as gr
import torch
import cv2
import numpy as np
from transformers import VideoMAEForVideoClassification, VideoMAEImageProcessor

# Video classification model (VideoMAE base checkpoint; for real predictions,
# swap in a checkpoint fine-tuned on your action labels, e.g. a UCF101 subset)
classification_model_id = "MCG-NJU/videomae-base"
# Object detection model name (placeholder, not yet wired up; replace with a
# more accurate model if needed)
object_detection_model = "yolov5s"
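# If real detection is added, one option (an assumption, not used below) is
#   detector = torch.hub.load("ultralytics/yolov5", "yolov5s", pretrained=True)
# and detect_object() would then filter detector(frame) results by class.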
# Parameters for frame extraction
TARGET_FRAME_COUNT = 16
FRAME_SIZE = (224, 224) # Expected frame size for the model
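# Note: VideoMAE base checkpoints are trained on clips of 16 frames at 224x224,
# which is why TARGET_FRAME_COUNT and FRAME_SIZE are set to these values.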


def analyze_video(video):
    # Extract key frames from the video using OpenCV
    frames = extract_key_frames(video)

    # Load classification model and image processor
    classification_model = VideoMAEForVideoClassification.from_pretrained(classification_model_id)
    processor = VideoMAEImageProcessor.from_pretrained(classification_model_id)

    # Prepare frames for the classification model (the processor normalizes the
    # list of RGB frames into a single video tensor)
    inputs = processor(images=frames, return_tensors="pt")

    # Make a prediction using the classification model
    with torch.no_grad():
        outputs = classification_model(**inputs)
    logits = outputs.logits  # shape (1, num_labels): one prediction per video
    prediction = torch.argmax(logits, dim=-1).item()

    # Object detection and tracking (ball and baseman)
    object_detection_results = []
    for frame in frames:
        ball_position = detect_object(frame, "ball")
        baseman_position = detect_object(frame, "baseman")
        object_detection_results.append((ball_position, baseman_position))

    # Analyze the video-level prediction together with the per-frame detections
    analysis_results = []
    for ball_position, baseman_position in object_detection_results:
        result = analyze_frame(prediction, ball_position, baseman_position)
        analysis_results.append(result)

    # Aggregate analysis results
    final_result = aggregate_results(analysis_results)
    return final_result


def extract_key_frames(video):
    cap = cv2.VideoCapture(video)
    frames = []
    frame_count = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
    interval = max(1, frame_count // TARGET_FRAME_COUNT)
    for i in range(frame_count):
        ret, frame = cap.read()
        if not ret:
            break
        # Extract frames at regular intervals, capped at TARGET_FRAME_COUNT
        if i % interval == 0 and len(frames) < TARGET_FRAME_COUNT:
            frame = cv2.resize(frame, FRAME_SIZE)  # Resize frame for the model
            frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)  # OpenCV reads BGR; the model expects RGB
            frames.append(frame)
    cap.release()
    return frames


def detect_object(frame, object_type):
    # Placeholder function for object detection (replace with actual implementation)
    # Here, we assume that the object is detected at the center of the frame
    h, w, _ = frame.shape
    if object_type == "ball":
        return (w // 2, h // 2)  # Return center coordinates for the ball
    elif object_type == "baseman":
        return (w // 2, h // 2)  # Return center coordinates for the baseman
    else:
        return None


def analyze_frame(prediction, ball_position, baseman_position):
    # Placeholder function for analyzing a single frame
    # You can replace this with actual logic based on your requirements
    action_labels = {
        0: "running",
        1: "sliding",
        2: "jumping",
        # Add more labels as necessary
    }
    action = action_labels.get(prediction, "unknown")
    return {"action": action, "ball_position": ball_position, "baseman_position": baseman_position}


def aggregate_results(results):
    # Placeholder function for aggregating analysis results
    # You can implement this based on your specific requirements
    return results


# Gradio interface
interface = gr.Interface(
    fn=analyze_video,
    inputs="video",
    outputs="text",
    title="Baseball Play Analysis",
    description="Upload a video of a baseball play to analyze.",
)

interface.launch()
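# launch() starts a local web server and blocks; passing share=True to launch()
# would expose a temporary public link (optional, assumes a default Gradio setup).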