"""Gradio app that analyzes an uploaded baseball-play video.

Frames are sampled with OpenCV, the clip is classified with a VideoMAE model, and
placeholder ball/baseman positions are attached to every sampled frame.
"""

import functools

import cv2
import gradio as gr
import numpy as np
import torch
from transformers import VideoMAEForVideoClassification, VideoMAEImageProcessor

# Hugging Face checkpoint used for clip-level action classification.
# NOTE(review): "videomae-base" appears to be the self-supervised pre-trained checkpoint,
# not one fine-tuned for classification (e.g. on UCF101). If so, the classification head
# is freshly initialised and its predictions are not meaningful. Confirm, and swap in a
# fine-tuned checkpoint.
classification_model_id = "MCG-NJU/videomae-base"

# Object detection model name (not used yet: detect_object below is a placeholder).
object_detection_model = "yolov5s"

# Number of frames sampled per video. This should match the model config's `num_frames`
# (assumed to be 16 for this checkpoint; TODO confirm).
TARGET_FRAME_COUNT = 16
# (width, height) passed to cv2.resize for every sampled frame.
FRAME_SIZE = (224, 224)


@functools.lru_cache(maxsize=1)
def _load_classifier():
    """Load the VideoMAE model and image processor once and reuse them across requests."""
    model = VideoMAEForVideoClassification.from_pretrained(classification_model_id)
    processor = VideoMAEImageProcessor.from_pretrained(classification_model_id)
    return model, processor


def analyze_video(video):
    """Classify the action in *video* and attach object positions to each sampled frame.

    Args:
        video: File path of the uploaded video, as passed by Gradio's video input.
            It is None when nothing was uploaded.

    Returns:
        The value of aggregate_results() applied to the per-frame analysis dicts.

    Raises:
        gr.Error: if no video was given, or the video cannot be opened or decoded.
    """
    if video is None:
        raise gr.Error("Please upload a video.")
    try:
        frames = extract_key_frames(video)
    except ValueError as err:
        raise gr.Error(str(err)) from err

    classification_model, processor = _load_classifier()
    inputs = processor(frames, return_tensors="pt")
    with torch.no_grad():
        logits = classification_model(**inputs).logits

    # VideoMAE classifies the clip as a whole. logits has one row per video in the
    # batch (a single video here), not one row per frame.
    clip_prediction = int(torch.argmax(logits, dim=-1)[0])

    # Object detection and tracking (ball and baseman): one entry per sampled frame.
    object_detection_results = [
        (detect_object(frame, "ball"), detect_object(frame, "baseman")) for frame in frames
    ]

    # Pair the clip-level action with every frame's detections. Zipping the one-row
    # prediction tensor with the frames would analyze only the first frame.
    analysis_results = [
        analyze_frame(clip_prediction, ball_position, baseman_position)
        for ball_position, baseman_position in object_detection_results
    ]
    return aggregate_results(analysis_results)


def _prepare_frame(frame):
    """Convert an OpenCV frame from BGR to RGB and resize it to FRAME_SIZE.

    OpenCV decodes frames in BGR order, while Hugging Face image processors
    assume RGB input.
    """
    rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
    return cv2.resize(rgb, FRAME_SIZE)


def _evenly_spaced(total, count=TARGET_FRAME_COUNT):
    """Return *count* non-decreasing indices spread evenly over ``range(total)`` (total >= 1)."""
    return np.linspace(0, total - 1, count).round().astype(int).tolist()


def _read_frames_at(cap, targets):
    """Decode *cap* sequentially and return prepared frames at the sorted indices *targets*.

    A duplicated index yields the same frame more than once. If the stream ends
    before the last target index, fewer frames are returned.
    """
    frames = []
    for index in range(targets[-1] + 1):
        ok, frame = cap.read()
        if not ok:
            break  # The container over-reported its length.
        if targets[len(frames)] != index:
            continue
        prepared = _prepare_frame(frame)
        while len(frames) < len(targets) and targets[len(frames)] == index:
            frames.append(prepared)
    return frames


def extract_key_frames(video):
    """Sample exactly TARGET_FRAME_COUNT evenly spaced RGB frames from *video*.

    Each frame is resized to FRAME_SIZE. Videos with fewer than TARGET_FRAME_COUNT
    frames repeat frames. If decoding stops early, the last decoded frame is
    repeated, so the model always receives a clip of the expected length.

    Args:
        video: Path (or any other source accepted by cv2.VideoCapture).

    Returns:
        A list of TARGET_FRAME_COUNT arrays of shape (FRAME_SIZE[1], FRAME_SIZE[0], 3).

    Raises:
        ValueError: if the video cannot be opened or contains no decodable frames.
    """
    cap = cv2.VideoCapture(video)
    if not cap.isOpened():
        raise ValueError(f"Could not open video: {video!r}")
    try:
        frame_count = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        if frame_count > 0:
            frames = _read_frames_at(cap, _evenly_spaced(frame_count))
        else:
            # Some containers do not report a frame count. Decode every frame
            # (downscaled) and sample afterwards.
            decoded = []
            ok, frame = cap.read()
            while ok:
                decoded.append(_prepare_frame(frame))
                ok, frame = cap.read()
            frames = [decoded[i] for i in _evenly_spaced(len(decoded))] if decoded else []
    finally:
        cap.release()

    if not frames:
        raise ValueError("No frames could be decoded from the video.")
    # The reported frame count can overestimate the real one, so pad with the last frame.
    frames.extend([frames[-1]] * (TARGET_FRAME_COUNT - len(frames)))
    return frames


def detect_object(frame, object_type):
    """Placeholder object detector: return (x, y) pixel coordinates of *object_type*.

    Replace with an actual implementation. For now "ball" and "baseman" are both
    reported at the centre of the frame, and any other type returns None.
    """
    h, w, _ = frame.shape
    if object_type == "ball":
        return (w // 2, h // 2)  # Centre coordinates for the ball.
    elif object_type == "baseman":
        return (w // 2, h // 2)  # Centre coordinates for the baseman.
    else:
        return None


def analyze_frame(prediction, ball_position, baseman_position):
    """Placeholder per-frame analysis: map the action index to a label and bundle the positions.

    NOTE(review): these labels are placeholders and do not come from the model's
    `config.id2label`. Confirm the mapping once a fine-tuned checkpoint is chosen.
    """
    action_labels = {
        0: "running",
        1: "sliding",
        2: "jumping",
        # Add more labels as necessary.
    }
    action = action_labels.get(prediction, "unknown")
    return {"action": action, "ball_position": ball_position, "baseman_position": baseman_position}


def aggregate_results(results):
    """Placeholder aggregation of per-frame analysis results (currently returns them unchanged)."""
    return results


# Gradio interface
interface = gr.Interface(
    fn=analyze_video,
    inputs="video",
    outputs="text",
    title="Baseball Play Analysis",
    description="Upload a video of a baseball play to analyze.",
)

if __name__ == "__main__":
    # Only start the server when run as a script, not when imported.
    interface.launch()