import tempfile

import gradio as gr
import torch
from pytorchvideo.data import make_clip_sampler
from pytorchvideo.data.clip_sampling import ClipInfoList
from pytorchvideo.data.encoded_video_pyav import EncodedVideoPyAV
from pytorchvideo.data.video import VideoPathHandler
from pytorchvideo.transforms import (
    Normalize,
    RandomShortSideScale,
    UniformTemporalSubsample,
)
from torchvision.transforms import (
    Compose,
    Lambda,
    RandomCrop,
)
from transformers import VideoMAEFeatureExtractor, VideoMAEForVideoClassification

from video_utils import change_video_resolution_and_fps

MODEL_CKPT = "omermazig/videomae-finetuned-nba-5-class-4-batch-8000-vid-multiclass"
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
CLIPS_FROM_SINGLE_VIDEO = 5

# Load the fine-tuned classifier and the processor that carries its preprocessing
# statistics (normalization mean/std and input resolution).
trained_model = VideoMAEForVideoClassification.from_pretrained(MODEL_CKPT).to(DEVICE)
image_processor = VideoMAEFeatureExtractor.from_pretrained(MODEL_CKPT)
mean = image_processor.image_mean
std = image_processor.image_std
if "shortest_edge" in image_processor.size:
    height = width = image_processor.size["shortest_edge"]
else:
    height = image_processor.size["height"]
    width = image_processor.size["width"]
resize_to = (height, width)
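# Note: VideoMAE processors typically report {"shortest_edge": 224}, so resize_to
# usually comes out as (224, 224); the else-branch covers processor configs that
# expose explicit "height"/"width" keys instead.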
num_frames_to_sample = trained_model.config.num_frames
sample_rate = 4
fps = 30
clip_duration = num_frames_to_sample * sample_rate / fps
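# Worked example: with VideoMAE's default of 16 frames per clip, a sample rate of
# 4 at 30 fps means each clip spans 16 * 4 / 30 ≈ 2.13 seconds of source video.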
# Inference-time transformations, mirroring the validation/test pipeline. The
# random scale and crop act as light test-time augmentation: several clips are
# sampled per video and their logits are aggregated in `infer` below.
inference_transform = Compose(
    [
        UniformTemporalSubsample(num_frames_to_sample),
        Lambda(lambda x: x / 255.0),
        Normalize(mean, std),
        RandomShortSideScale(min_size=256, max_size=320),
        RandomCrop(resize_to),
    ]
)
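# A minimal shape sketch (hypothetical tensor, not used by the app): `get_clip`
# returns a (C, T, H, W) float tensor in [0, 255], and the transform yields
# num_frames_to_sample frames at spatial size resize_to, e.g.:
#
#   dummy_clip = torch.rand(3, 64, 360, 640) * 255
#   assert inference_transform(dummy_clip).shape == (3, num_frames_to_sample, *resize_to)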
# Order class names by logit index (config.id2label is keyed by integer class id).
labels = [trained_model.config.id2label[i] for i in range(len(trained_model.config.id2label))]

def parse_video_to_clips(video_file):
    """Re-encode the input video, sample clips from it, and return one batched tensor."""
    new_resolution = (320, 256)
    new_fps = 30
    acceptable_fps_violation = 5

    # The ".mp4" suffix is an assumption: it lets the conversion utility infer the
    # output container from the file extension.
    with tempfile.NamedTemporaryFile(suffix=".mp4") as new_video:
        change_video_resolution_and_fps(video_file, new_video.name, new_resolution, new_fps, acceptable_fps_violation)
        video_path_handler = VideoPathHandler()
        # Decode the re-encoded file rather than the original upload, so the clip
        # duration computed for 30 fps matches the actual frame rate.
        video: EncodedVideoPyAV = video_path_handler.video_from_path(new_video.name)
        # Sample CLIPS_FROM_SINGLE_VIDEO random windows of clip_duration seconds each.
        clip_sampler = make_clip_sampler("random_multi", clip_duration, CLIPS_FROM_SINGLE_VIDEO)
        # noinspection PyTypeChecker
        clip_info: ClipInfoList = clip_sampler(0, video.duration, {})

        video_clips_list = []
        for clip_start, clip_end in zip(clip_info.clip_start_sec, clip_info.clip_end_sec):
            video_clip = video.get_clip(clip_start, clip_end)["video"]
            video_clips_list.append(inference_transform(video_clip))

    videos_tensor = torch.stack([single_clip.permute(1, 0, 2, 3) for single_clip in video_clips_list])
    return videos_tensor.to(DEVICE)

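# Shape note: each transformed clip is (C, T, H, W); the permute above reorders it
# to (T, C, H, W), so the stacked batch handed to VideoMAE is
# (CLIPS_FROM_SINGLE_VIDEO, num_frames_to_sample, 3, *resize_to), matching the
# (batch, frames, channels, height, width) layout `pixel_values` expects.
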
def infer(video_file):
    videos_tensor = parse_video_to_clips(video_file)
    inputs = {"pixel_values": videos_tensor}

    # Forward pass over all sampled clips at once.
    with torch.no_grad():
        outputs = trained_model(**inputs)
    multiple_logits = outputs.logits

    # Aggregate per-clip logits into a single video-level prediction.
    logits = multiple_logits.sum(dim=0)
    softmax_scores = torch.nn.functional.softmax(logits, dim=-1)
    confidences = {labels[i]: float(softmax_scores[i]) for i in range(len(labels))}
    return confidences

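# Example usage outside Gradio (kept commented out, since it runs a real forward
# pass over one of the bundled example clips):
#
#   confidences = infer("examples/DUNK.avi")
#   print(max(confidences, key=confidences.get))
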
gr.Interface(
    fn=infer,
    inputs=gr.Video(),
    outputs=gr.Label(num_top_classes=3),
    examples=[
        ["examples/DUNK.avi"],
        ["examples/FLOATING_JUMP_SHOT.avi"],
        ["examples/JUMP_SHOT.avi"],
        ["examples/REVERSE_LAYUP.avi"],
        ["examples/TURNAROUND_HOOK_SHOT.avi"],
    ],
    title="VideoMAE fine-tuned on NBA data",
    description=(
        "Gradio demo for VideoMAE video classification. Upload a video, or click one of the"
        " examples to load it. Read more at the links below."
    ),
    article=(
        "<div style='text-align: center;'>"
        "<a href='https://huggingface.co/docs/transformers/model_doc/videomae' target='_blank'>VideoMAE</a>"
        " | <a href='https://huggingface.co/omermazig/videomae-finetuned-nba-5-class-4-batch-8000-vid-multiclass' target='_blank'>Fine-tuned Model</a>"
        "</div>"
    ),
    # `allow_flagging=False`, `allow_screenshot`, and `gr.Video(type="file")` are
    # pre-3.x Gradio arguments; "never" is the current equivalent for flagging.
    allow_flagging="never",
).launch()