|
import gradio as gr |
|
import torch |
|
import cv2 |
|
import numpy as np |
|
from transformers import VideoMAEForVideoClassification, VideoMAEImageProcessor |
|
|
|
|
|
# Hugging Face checkpoint id that analyze_video loads for action classification.
classification_model_id = "MCG-NJU/videomae-base"

# Name of the object-detection model. detect_object is still a placeholder
# and does not load this model.
object_detection_model = "yolov5s"

# Clip length in frames -- presumably the number of frames the VideoMAE
# checkpoint expects per clip (TODO confirm against the model config).
TARGET_FRAME_COUNT = 16

# Frame resolution in pixels -- presumably the model's input size.
FRAME_SIZE = (224, 224)
|
# Loaded (model, processor) pairs keyed by checkpoint id, so the weights are
# read once per process instead of on every request.
_CLASSIFIER_CACHE = {}


def _load_classifier(model_id):
    """Return the cached ``(model, processor)`` pair for *model_id*."""
    if model_id not in _CLASSIFIER_CACHE:
        _CLASSIFIER_CACHE[model_id] = (
            VideoMAEForVideoClassification.from_pretrained(model_id),
            VideoMAEImageProcessor.from_pretrained(model_id),
        )
    return _CLASSIFIER_CACHE[model_id]


def _sample_clip(frames, count):
    """Pick exactly *count* evenly spaced frames, repeating some if too few."""
    positions = np.linspace(0, len(frames) - 1, num=count).round().astype(int)
    return [frames[i] for i in positions]


def analyze_video(video):
    """Classify the action in a video and locate objects in its key frames.

    The key frames are resampled to ``TARGET_FRAME_COUNT`` frames and
    classified as a single clip; the resulting action id is then attached to
    the object positions found in every key frame.

    Args:
        video: Path of the uploaded video file, or ``None``/empty when
            nothing was uploaded.

    Returns:
        The value of ``aggregate_results`` applied to the list of per-frame
        result dicts produced by ``analyze_frame``. The list is empty when
        no video was given or no frame could be read.
    """
    if not video:
        return aggregate_results([])

    frames = extract_key_frames(video)
    if not frames:
        # Nothing to classify; the processor would fail on an empty clip.
        return aggregate_results([])

    classification_model, processor = _load_classifier(classification_model_id)

    # VideoMAE works on fixed-length clips, so the arbitrary number of key
    # frames is resampled to TARGET_FRAME_COUNT. OpenCV yields BGR frames
    # while the image processor expects RGB, hence the conversion.
    clip = [
        cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
        for frame in _sample_clip(frames, TARGET_FRAME_COUNT)
    ]
    inputs = processor(images=clip, return_tensors="pt")

    with torch.no_grad():
        outputs = classification_model(**inputs)

    # The logits hold one row per clip, not per frame. A single clip was
    # submitted, so there is exactly one predicted class id.
    # NOTE(review): "MCG-NJU/videomae-base" looks like the pre-training
    # checkpoint, whose classification head would be untrained -- confirm
    # and switch to a fine-tuned checkpoint if so.
    prediction = int(torch.argmax(outputs.logits, dim=-1)[0])

    # Pair the clip-level action with the detections of every key frame.
    # Zipping the length-1 prediction batch with the frames would silently
    # drop all but the first frame.
    analysis_results = [
        analyze_frame(
            prediction,
            detect_object(frame, "ball"),
            detect_object(frame, "baseman"),
        )
        for frame in frames
    ]

    return aggregate_results(analysis_results)
|
|
|
def extract_key_frames(video):
    """Read a video with OpenCV and keep roughly two frames per second.

    Every ``fps // 2``-th frame is kept, starting with the first one. When
    the container reports no usable frame rate (or one below 2), every frame
    is kept instead of dividing by zero.

    Args:
        video: Path (or any other source accepted by ``cv2.VideoCapture``).

    Returns:
        The selected frames in playback order, exactly as returned by
        ``VideoCapture.read`` (OpenCV's BGR channel order). The list is
        empty when the source cannot be opened or holds no frames.
    """
    cap = cv2.VideoCapture(video)
    try:
        if not cap.isOpened():
            return []

        fps = cap.get(cv2.CAP_PROP_FPS)
        # A NaN or non-positive FPS fails the comparison and falls back to 1.
        step = max(1, int(fps) // 2) if fps > 0 else 1

        frames = []
        index = 0
        # Read until the decoder runs dry rather than trusting
        # CAP_PROP_FRAME_COUNT, which can be inaccurate or 0.
        while True:
            ret, frame = cap.read()
            if not ret:
                break
            if index % step == 0:
                frames.append(frame)
            index += 1
        return frames
    finally:
        # Release the capture handle even if reading raised.
        cap.release()
|
|
|
def detect_object(frame, object_type):
    """Return a placeholder position for the requested object in a frame.

    No detection is performed: for the supported object types the centre of
    the frame is returned.

    Args:
        frame: Image whose ``shape`` is ``(height, width, channels)``.
        object_type: ``"ball"`` or ``"baseman"``.

    Returns:
        ``(x, y)`` of the frame centre for a supported type, else ``None``.
    """
    height, width, _ = frame.shape
    if object_type not in ("ball", "baseman"):
        return None
    return (width // 2, height // 2)
|
|
|
def analyze_frame(prediction, ball_position, baseman_position):
    """Bundle an action label with the object positions of one frame.

    Args:
        prediction: Class id from the classifier; ids 0, 1 and 2 map to
            "running", "sliding" and "jumping", anything else to "unknown".
        ball_position: Position reported for the ball (passed through).
        baseman_position: Position reported for the baseman (passed through).

    Returns:
        Dict with the keys ``"action"``, ``"ball_position"`` and
        ``"baseman_position"``.
    """
    names_by_id = dict(enumerate(("running", "sliding", "jumping")))
    try:
        label = names_by_id[prediction]
    except KeyError:
        label = "unknown"
    return {
        "action": label,
        "ball_position": ball_position,
        "baseman_position": baseman_position,
    }
|
|
|
def aggregate_results(results):
    """Combine per-frame results into the final analysis.

    Currently a pass-through: the given object is returned unchanged.

    Args:
        results: Per-frame result dicts.

    Returns:
        The same ``results`` object that was passed in.
    """
    aggregated = results
    return aggregated
|
|
|
|
|
# Gradio front end: a video upload goes in, the analysis comes back as text.
interface = gr.Interface(
    analyze_video,
    "video",
    "text",
    title="Baseball Play Analysis",
    description="Upload a video of a baseball play to analyze.",
)

# Start the web server for the interface.
interface.launch()
|
|