import gradio as gr
import torch
from pytorchvideo.data import make_clip_sampler
from pytorchvideo.data.clip_sampling import ClipInfoList
from pytorchvideo.data.encoded_video_pyav import EncodedVideoPyAV
from pytorchvideo.data.video import VideoPathHandler
from pytorchvideo.transforms import (
    Normalize,
    RandomShortSideScale,
    UniformTemporalSubsample,
)
from torchvision.transforms import (
    Compose,
    Lambda,
    RandomCrop,
    Resize,
)
from transformers import pipeline
MODEL_CKPT = "omermazig/videomae-finetuned-nba-5-class-8-batch-8000-vid-multiclass_1697155188"
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
CLIPS_FROM_SINGLE_VIDEO = 5

# The pipeline is only used to pull down the fine-tuned model and its image processor;
# inference below calls the model directly on a batch of clips.
pipe = pipeline("video-classification", model=MODEL_CKPT)
trained_model = pipe.model.to(DEVICE).eval()
image_processor = pipe.image_processor

# Pre-processing parameters derived from the image processor and the model config.
mean = image_processor.image_mean
std = image_processor.image_std
if "shortest_edge" in image_processor.size:
    height = width = image_processor.size["shortest_edge"]
else:
    height = image_processor.size["height"]
    width = image_processor.size["width"]
resize_to = (height, width)

num_frames_to_sample = trained_model.config.num_frames
sample_rate = 4
fps = 30
clip_duration = num_frames_to_sample * sample_rate / fps
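# Worked example of the clip length (assuming the checkpoint keeps VideoMAE's
# default of 16 frames): clip_duration = 16 * 4 / 30 ≈ 2.13 seconds per sampled clip.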
# Transform applied to every sampled clip. The random scale/crop give each of the
# CLIPS_FROM_SINGLE_VIDEO clips a slightly different view, which acts as light
# test-time augmentation once the per-clip logits are aggregated in infer().
inference_transform = Compose(
    [
        UniformTemporalSubsample(num_frames_to_sample),
        Lambda(lambda x: x / 255.0),
        Normalize(mean, std),
        RandomShortSideScale(min_size=256, max_size=320),
        RandomCrop(resize_to),
    ]
)

# config.id2label maps class index -> label name, matching the order of the logits.
labels = [trained_model.config.id2label[i] for i in range(trained_model.config.num_labels)]
def parse_video_to_clips(video_file):
    """Sample CLIPS_FROM_SINGLE_VIDEO random clips from the input video and stack them into a batch."""
    video_path_handler = VideoPathHandler()
    video: EncodedVideoPyAV = video_path_handler.video_from_path(video_file)
    clip_sampler = make_clip_sampler("random_multi", clip_duration, CLIPS_FROM_SINGLE_VIDEO)
    # noinspection PyTypeChecker
    clip_info: ClipInfoList = clip_sampler(0, video.duration, {})

    video_clips_list = []
    for clip_start, clip_end in zip(clip_info.clip_start_sec, clip_info.clip_end_sec):
        # get_clip returns a (C, T, H, W) tensor with values in [0, 255] under the "video" key.
        video_clip = video.get_clip(clip_start, clip_end)["video"]
        video_clips_list.append(inference_transform(video_clip))
    # VideoMAE expects pixel_values of shape (batch, num_frames, channels, height, width),
    # so permute each transformed clip from (C, T, H, W) to (T, C, H, W) before stacking.
    videos_tensor = torch.stack([single_clip.permute(1, 0, 2, 3) for single_clip in video_clips_list])
    return videos_tensor
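# Rough shape check for the function above (assuming the usual VideoMAE defaults of
# 16 frames and a 224x224 crop): the returned tensor has shape
# (CLIPS_FROM_SINGLE_VIDEO, 16, 3, 224, 224), i.e. a batch of 5 clips.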
def infer(video_file):
    videos_tensor = parse_video_to_clips(video_file)
    inputs = {"pixel_values": videos_tensor.to(DEVICE)}

    # Forward pass over all sampled clips in a single batch.
    with torch.no_grad():
        outputs = trained_model(**inputs)
    multiple_logits = outputs.logits
    # Aggregate the per-clip logits into a single prediction for the whole video.
    logits = multiple_logits.sum(dim=0)

    softmax_scores = torch.nn.functional.softmax(logits, dim=-1)
    confidences = {labels[i]: float(softmax_scores[i]) for i in range(len(labels))}
    return confidences
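# Note on the aggregation in infer(): summing the per-clip logits before the softmax
# is the same as averaging them up to a constant factor, so the predicted ranking is
# identical; the summed version only yields sharper (more peaked) probabilities.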
gr.Interface(
    fn=infer,
    inputs=gr.Video(),
    outputs=gr.Label(num_top_classes=3),
    examples=[
        ["examples/babycrawling.mp4"],
        ["examples/baseball.mp4"],
        ["examples/balancebeam.mp4"],
    ],
    title="VideoMAE fine-tuned on a subset of UCF-101",
    description=(
        "Gradio demo for VideoMAE video classification. To use it, upload a video or click one of the"
        " examples to load it. Read more at the links below."
    ),
    article=(
        "<div style='text-align: center;'><a href='https://huggingface.co/docs/transformers/model_doc/videomae' target='_blank'>VideoMAE</a>"
        " <center><a href='https://huggingface.co/sayakpaul/videomae-base-finetuned-kinetics-finetuned-ucf101-subset' target='_blank'>Fine-tuned Model</a></center></div>"
    ),
    allow_flagging="never",
).launch()
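# For a quick sanity check without the UI, one could also call infer() directly on a
# local video, e.g. print(infer("examples/baseball.mp4")).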