import torch
from transformers import VideoMAEForVideoClassification, VideoMAEFeatureExtractor
import os, cv2, uuid, shutil
import numpy as np
import gdown

model_path = "checkpoint_epoch_1.pt"
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
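
# Fetch the fine-tuned checkpoint on first run; later runs reuse the cached file.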
if not os.path.exists(model_path):
    print("Downloading checkpoint...")
    url = "https://drive.google.com/uc?id=1dIaptYPq-1fgo0yoBoPlDsbIfs3BEqJI"
    gdown.download(url, model_path, quiet=False)
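
# Rebuild the 3-class classification head on the VideoMAE backbone, then restore
# the fine-tuned weights from the checkpoint.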
model = VideoMAEForVideoClassification.from_pretrained("MCG-NJU/videomae-base", num_labels=3)
checkpoint = torch.load(model_path, map_location=device)
model.load_state_dict(checkpoint["model_state_dict"])
model.eval().to(device)

# Resizes and normalizes frames to match VideoMAE's pretraining (recent transformers
# versions expose the same functionality as VideoMAEImageProcessor).
feature_extractor = VideoMAEFeatureExtractor.from_pretrained("MCG-NJU/videomae-base")
label_map = {0: "Goal", 1: "Card", 2: "Substitution"}


def predict_gradio(video):
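    # Stage each request in its own temp directory so concurrent calls don't collide.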
    video_id = str(uuid.uuid4())
    work_dir = f"./temp/{video_id}"
    os.makedirs(work_dir, exist_ok=True)
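
    # Accept either a filepath (str) or a file-like object from the UI layer.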
    temp_video_path = os.path.join(work_dir, "input.mp4")
    if isinstance(video, str):
        shutil.copy(video, temp_video_path)
    else:
        with open(temp_video_path, "wb") as f:
            f.write(video.read())

    cap = cv2.VideoCapture(temp_video_path)
    fps = cap.get(cv2.CAP_PROP_FPS)
    if fps == 0 or fps != fps:  # fps != fps catches NaN
        cap.release()
        return [{"error": "Invalid or unreadable video."}], ""
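
    # Decode the whole clip up front, resized to VideoMAE's 224x224 input resolution.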
    frames = []
    while True:
        ret, frame = cap.read()
        if not ret:
            break
        frames.append(cv2.resize(frame, (224, 224)))
    cap.release()

    segment_size = int(fps * 5)  # score the video in 5-second windows
    predictions = []
    output_segments = []
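
    # Classify each window independently; VideoMAE consumes 16 frames per clip.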
    for i in range(0, len(frames), segment_size):
        segment = frames[i:i + segment_size]
        if len(segment) < 16:
            continue
        # Sample 16 frames uniformly; convert OpenCV's BGR to RGB for the model
        # (this assumes the fine-tuning pipeline also fed RGB frames).
        indices = np.linspace(0, len(segment) - 1, 16).astype(int)
        sampled_frames = [cv2.cvtColor(segment[idx], cv2.COLOR_BGR2RGB) for idx in indices]

        inputs = feature_extractor(sampled_frames, return_tensors="pt")
        inputs = {k: v.to(device) for k, v in inputs.items()}

        with torch.no_grad():
            outputs = model(**inputs)
            probs = torch.nn.functional.softmax(outputs.logits, dim=1)
            confidence, pred = torch.max(probs, dim=1)
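
        # Keep only confident windows; their raw frames feed the highlight reel below.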
        if confidence.item() > 0.70:
            label = label_map[pred.item()]
            start_time = i / fps
            end_time = min(i + segment_size, len(frames)) / fps
            predictions.append({
                "start": round(start_time, 2),
                "end": round(end_time, 2),
                "label": label,
                "confidence": round(confidence.item(), 3),
            })
            output_segments.append(segment)
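
    # Stitch every kept segment into a single summary clip.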
    out_path = f"{work_dir}/summary.mp4"
    if output_segments:
        fourcc = cv2.VideoWriter_fourcc(*"mp4v")
        out = cv2.VideoWriter(out_path, fourcc, fps, (224, 224))
        for seg in output_segments:
            for frame in seg:
                out.write(frame)  # frames are still BGR, which VideoWriter expects
        out.release()
        return predictions, out_path
    else:
        return predictions, ""
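

# A minimal sketch of serving predict_gradio as a web demo. The original snippet does
# not show the UI wiring, so the component choices below (gr.Video in, gr.JSON plus
# gr.Video out) are an assumption, not the author's exact configuration. Note that
# predict_gradio returns "" when no events are found; some Gradio versions prefer
# None for an empty video output.
if __name__ == "__main__":
    import gradio as gr

    demo = gr.Interface(
        fn=predict_gradio,
        inputs=gr.Video(label="Match clip"),
        outputs=[gr.JSON(label="Detected events"), gr.Video(label="Highlight summary")],
        title="Football Event Detection (VideoMAE)",
    )
    demo.launch()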