File size: 3,834 Bytes
8760fb5
 
e0d185f
 
 
 
8760fb5
 
e0d185f
8760fb5
 
 
 
e0d185f
8760fb5
e0d185f
8760fb5
 
e0d185f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8760fb5
 
 
e0d185f
 
8760fb5
 
 
e0d185f
 
 
c5a535e
e0d185f
8760fb5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
import gradio as gr
import torch
from pytorchvideo.data import make_clip_sampler
from pytorchvideo.data.clip_sampling import ClipInfoList
from pytorchvideo.data.encoded_video_pyav import EncodedVideoPyAV
from pytorchvideo.data.video import VideoPathHandler
from pytorchvideo.transforms import (
    Normalize,
    UniformTemporalSubsample, RandomShortSideScale,
)
from torchvision.transforms import (
    Compose,
    Lambda,
    Resize, RandomCrop,
)
from transformers import pipeline


MODEL_CKPT = "omermazig/videomae-finetuned-nba-5-class-4-batch-8000-vid-multiclass"
# NOTE(review): DEVICE is not referenced by the code visible in this file; the
# model stays wherever `pipeline` placed it.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Number of clips sampled from each input video and classified together.
CLIPS_FROM_SINGLE_VIDEO = 5

# Load the fine-tuned checkpoint through a pipeline, then keep direct handles
# on its model and image processor so preprocessing can be done by hand.
pipe = pipeline("video-classification", model=MODEL_CKPT)
trained_model = pipe.model
image_processor = pipe.image_processor

# Per-channel normalization statistics configured on the image processor.
mean, std = image_processor.image_mean, image_processor.image_std

# The processor's target size is given either as one "shortest_edge" value
# (used here for both sides) or as explicit "height" / "width" entries.
if "shortest_edge" not in image_processor.size:
    height, width = image_processor.size["height"], image_processor.size["width"]
else:
    height = width = image_processor.size["shortest_edge"]
resize_to = (height, width)

# A clip covers `num_frames_to_sample` frames taken every `sample_rate`-th
# frame of a video assumed to run at `fps`, i.e. this many seconds:
num_frames_to_sample = trained_model.config.num_frames
sample_rate, fps = 4, 30
clip_duration = num_frames_to_sample * sample_rate / fps

# Deterministic preprocessing for inference (validation/test style), applied to
# every sampled clip: pick `num_frames_to_sample` evenly spaced frames, scale
# pixel values to [0, 1], normalize with the checkpoint's statistics, and
# resize spatially to `resize_to`.
# Random scaling/cropping is deliberately NOT used here: that is training-time
# augmentation and would make predictions for the same video vary between runs.
# torchvision's Resize acts on the last two (H, W) dims, so the leading
# channel/time dims are preserved (clip layout presumably (C, T, H, W) — confirm).
inference_transform = Compose(
    [
        UniformTemporalSubsample(num_frames_to_sample),
        Lambda(lambda x: x / 255.0),
        Normalize(mean, std),
        Resize(resize_to),
    ]
)

# Class names ordered by class index, so that labels[i] names logit i.
# Built from id2label rather than label2id.keys(): the key order of label2id is
# not guaranteed to follow the class indices.
labels = [
    trained_model.config.id2label[i]
    for i in range(trained_model.config.num_labels)
]


def parse_video_to_clips(video_file):
    """Decode ``video_file`` into a stacked batch of preprocessed clips.

    ``CLIPS_FROM_SINGLE_VIDEO`` clips of ``clip_duration`` seconds are drawn at
    random positions with pytorchvideo's "random_multi" sampler. Each clip goes
    through ``inference_transform``, has its first two axes swapped, and the
    results are stacked into one tensor.

    Args:
        video_file: Path of the video to decode (as provided by the Gradio
            video input or an example entry).

    Returns:
        A tensor of shape (num_clips, dim1, dim0, H, W), where (dim0, dim1)
        are the first two axes of each transformed clip. Presumably this is
        (num_clips, T, C, H, W) — confirm against the transform output.

    Raises:
        ValueError: If none of the sampled clips contained decodable frames.
    """
    video_path_handler = VideoPathHandler()
    video: EncodedVideoPyAV = video_path_handler.video_from_path(video_file)
    try:
        clip_sampler = make_clip_sampler("random_multi", clip_duration, CLIPS_FROM_SINGLE_VIDEO)
        # noinspection PyTypeChecker
        clip_info: ClipInfoList = clip_sampler(0, video.duration, {})

        video_clips_list = []
        for clip_start, clip_end in zip(clip_info.clip_start_sec, clip_info.clip_end_sec):
            video_clip = video.get_clip(clip_start, clip_end)["video"]
            # The decoder may return None when no frames fall inside the
            # window; skip such clips instead of crashing in the transform.
            if video_clip is None:
                continue
            video_clips_list.append(inference_transform(video_clip))
    finally:
        # Release the decoder's container / file handle even on failure.
        video.close()

    if not video_clips_list:
        raise ValueError(f"Could not decode any frames from video: {video_file}")

    # Swap the first two axes of every clip before stacking; VideoMAE's
    # pixel_values are documented as (batch, num_frames, channels, H, W).
    return torch.stack([clip.permute(1, 0, 2, 3) for clip in video_clips_list])


def infer(video_file):
    """Classify ``video_file`` and return per-label confidences.

    Several clips are sampled from the video by ``parse_video_to_clips``. The
    model is run on all of them at once, their logits are averaged, and a
    softmax turns the result into a probability per label.

    Args:
        video_file: Path of the uploaded or example video.

    Returns:
        dict mapping each label name (from ``labels``) to its softmax score as
        a Python float, in the form ``gr.Label`` accepts.
    """
    videos_tensor = parse_video_to_clips(video_file)
    # Put the inputs on the same device as the model's parameters.
    inputs = {"pixel_values": videos_tensor.to(trained_model.device)}

    # forward pass
    with torch.no_grad():
        outputs = trained_model(**inputs)
    # Average (rather than sum) the per-clip logits. Summing equals N times the
    # mean, i.e. a softmax with temperature 1/N, which overstates confidence
    # as the number of clips grows. The argmax is the same either way.
    logits = outputs.logits.mean(dim=0)
    softmax_scores = torch.nn.functional.softmax(logits, dim=-1)
    return {label: float(score) for label, score in zip(labels, softmax_scores)}


# Build and launch the Gradio demo: the uploaded video is handed to `infer`,
# and its {label: confidence} dict is shown with the top 3 classes.
gr.Interface(
    fn=infer,
    inputs=gr.Video(),
    outputs=gr.Label(num_top_classes=3),
    # NOTE(review): these example clips look like they come from a UCF-101 demo
    # and may not match the NBA classes this checkpoint predicts — confirm.
    examples=[
        ["examples/babycrawling.mp4"],
        ["examples/baseball.mp4"],
        ["examples/balancebeam.mp4"],
    ],
    title="VideoMAE fine-tuned on NBA video clips",
    description=(
        "Gradio demo for VideoMAE for video classification. To use it, simply upload your video or click one of the"
        " examples to load them. Read more at the links below."
    ),
    article=(
        "<div style='text-align: center;'><a href='https://huggingface.co/docs/transformers/model_doc/videomae' target='_blank'>VideoMAE</a>"
        f" <center><a href='https://huggingface.co/{MODEL_CKPT}' target='_blank'>Fine-tuned Model</a></center></div>"
    ),
    # Flagging is disabled; Gradio 3+ expects the string "never" here.
    allow_flagging="never",
).launch()