import tempfile

import gradio as gr
import torch
from pytorchvideo.data import make_clip_sampler
from pytorchvideo.data.clip_sampling import ClipInfoList
from pytorchvideo.data.encoded_video_pyav import EncodedVideoPyAV
from pytorchvideo.data.video import VideoPathHandler
from pytorchvideo.transforms import (
    Normalize,
    RandomShortSideScale,
    UniformTemporalSubsample,
)
from torchvision.transforms import (
    Compose,
    Lambda,
    RandomCrop,
)
from transformers import VideoMAEFeatureExtractor, VideoMAEForVideoClassification

from video_utils import change_video_resolution_and_fps

MODEL_CKPT = "omermazig/videomae-finetuned-nba-5-class-4-batch-8000-vid-multilabel"
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
CLIPS_FROM_SINGLE_VIDEO = 5

trained_model = VideoMAEForVideoClassification.from_pretrained(MODEL_CKPT).to(DEVICE)
image_processor = VideoMAEFeatureExtractor.from_pretrained(MODEL_CKPT)

mean = image_processor.image_mean
std = image_processor.image_std
if "shortest_edge" in image_processor.size:
    height = width = image_processor.size["shortest_edge"]
else:
    height = image_processor.size["height"]
    width = image_processor.size["width"]
resize_to = (height, width)

num_frames_to_sample = trained_model.config.num_frames
sample_rate = 4
fps = 30
clip_duration = num_frames_to_sample * sample_rate / fps

# The same transformations applied to the validation and test sets during
# training. The Random* ops vary per sampled clip, which is intended here:
# predictions are averaged over several randomly sampled clips below.
inference_transform = Compose(
    [
        UniformTemporalSubsample(num_frames_to_sample),
        Lambda(lambda x: x / 255.0),
        Normalize(mean, std),
        RandomShortSideScale(min_size=256, max_size=320),
        RandomCrop(resize_to),
    ]
)

num_labels = trained_model.config.num_labels
labels = [trained_model.config.id2label[i] for i in range(num_labels)]


def parse_video_to_clips(video_file):
    """Re-encode the input video, sample CLIPS_FROM_SINGLE_VIDEO random clips
    from it, and stack them into a single batched tensor."""
    new_resolution = (320, 256)
    new_fps = 30
    acceptable_fps_violation = 5

    with tempfile.NamedTemporaryFile() as new_video:
        change_video_resolution_and_fps(
            video_file, new_video.name, new_resolution, new_fps, acceptable_fps_violation
        )

        video_path_handler = VideoPathHandler()
        # Decode the re-encoded file (not the original upload) while the
        # temporary file still exists.
        video: EncodedVideoPyAV = video_path_handler.video_from_path(new_video.name)

        clip_sampler = make_clip_sampler("random_multi", clip_duration, CLIPS_FROM_SINGLE_VIDEO)
        # noinspection PyTypeChecker
        clip_info: ClipInfoList = clip_sampler(0, video.duration, {})

        video_clips_list = []
        for clip_start, clip_end in zip(clip_info.clip_start_sec, clip_info.clip_end_sec):
            video_clip = video.get_clip(clip_start, clip_end)["video"]
            video_clips_list.append(inference_transform(video_clip))

    # (C, T, H, W) -> (T, C, H, W) per clip, stacked into (num_clips, T, C, H, W).
    videos_tensor = torch.stack([single_clip.permute(1, 0, 2, 3) for single_clip in video_clips_list])
    return videos_tensor.to(DEVICE)
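
# `change_video_resolution_and_fps` comes from the repo-local `video_utils` module,
# which is not shown here. Judging from the call above, it re-encodes `video_file`
# into `new_video.name` at the requested resolution and fps, allowing the source
# fps to deviate by up to `acceptable_fps_violation`. A minimal ffmpeg-python
# sketch of such a helper (an assumption, not the actual implementation; the fps
# tolerance check is omitted) could look like:
#
#     import ffmpeg
#
#     def change_video_resolution_and_fps(src, dst, resolution, fps, fps_violation):
#         width, height = resolution  # assumed (width, height) ordering
#         (
#             ffmpeg.input(src)
#             .filter("scale", width, height)
#             .filter("fps", fps=fps)
#             .output(dst, format="mp4")  # explicit format: the temp file has no suffix
#             .overwrite_output()
#             .run(quiet=True)
#         )
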
def infer(video_file):
    videos_tensor = parse_video_to_clips(video_file)
    inputs = {"pixel_values": videos_tensor}

    # Forward pass over all sampled clips, then average their logits into a
    # single prediction for the whole video.
    with torch.no_grad():
        outputs = trained_model(**inputs)
    logits = outputs.logits.mean(dim=0)

    # The model is multilabel, so apply a sigmoid (rather than a softmax) to get
    # an independent confidence per class.
    sigmoid_scores = torch.sigmoid(logits)
    confidences = {labels[i]: float(sigmoid_scores[i]) for i in range(num_labels)}
    return confidences


gr.Interface(
    fn=infer,
    inputs=gr.Video(type="file"),
    outputs=gr.Label(num_top_classes=10),
    examples=[
        ["examples/DUNK.avi"],
        ["examples/FLOATING_JUMP_SHOT.avi"],
        ["examples/JUMP_SHOT.avi"],
        ["examples/REVERSE_LAYUP.avi"],
        ["examples/TURNAROUND_HOOK_SHOT.avi"],
    ],
    title="VideoMAE fine-tuned on NBA data",
    description=(
        "Gradio demo for VideoMAE video classification. To use it, simply upload your video or click one of the"
        " examples to load it. Read more at the links below."
    ),
    # Assumed link targets: the VideoMAE paper and the fine-tuned checkpoint on the Hub.
    article=(
        "<div style='text-align: center;'>"
        "<a href='https://arxiv.org/abs/2203.12602' target='_blank'>VideoMAE</a> | "
        f"<a href='https://huggingface.co/{MODEL_CKPT}' target='_blank'>Fine-tuned Model</a>"
        "</div>"
    ),
    allow_flagging=False,
    allow_screenshot=False,
).launch()