import tempfile

import gradio as gr
import torch
from pytorchvideo.data import make_clip_sampler
from pytorchvideo.data.clip_sampling import ClipInfoList
from pytorchvideo.data.encoded_video_pyav import EncodedVideoPyAV
from pytorchvideo.data.video import VideoPathHandler
from pytorchvideo.transforms import (
    Normalize,
    RandomShortSideScale,
    UniformTemporalSubsample,
)
from torchvision.transforms import (
    Compose,
    Lambda,
    RandomCrop,
)
from transformers import VideoMAEForVideoClassification, VideoMAEFeatureExtractor

from video_utils import change_video_resolution_and_fps

MODEL_CKPT = "omermazig/videomae-finetuned-nba-5-class-4-batch-8000-vid-multiclass"
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
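# Number of random clips sampled from each uploaded video; their logits are summed in `infer`
# to produce a single video-level prediction.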
CLIPS_FROM_SINGLE_VIDEO = 5

trained_model = VideoMAEForVideoClassification.from_pretrained(MODEL_CKPT).to(DEVICE)
image_processor = VideoMAEFeatureExtractor.from_pretrained(MODEL_CKPT)

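# Normalization statistics and target input size come from the checkpoint's preprocessing configuration.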
mean = image_processor.image_mean
std = image_processor.image_std
if "shortest_edge" in image_processor.size:
    height = width = image_processor.size["shortest_edge"]
else:
    height = image_processor.size["height"]
    width = image_processor.size["width"]
resize_to = (height, width)

num_frames_to_sample = trained_model.config.num_frames
sample_rate = 4
fps = 30
clip_duration = num_frames_to_sample * sample_rate / fps
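# With a 16-frame checkpoint (VideoMAE's default), this works out to 16 * 4 / 30 ≈ 2.13 seconds of
# source video per clip; the exact value depends on `num_frames` in the checkpoint's config.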

# Transformations applied to each sampled clip at inference time (mirrors the pipeline used for the
# validation/test sets during training).
inference_transform = Compose(
    [
        UniformTemporalSubsample(num_frames_to_sample),
        Lambda(lambda x: x / 255.0),
        Normalize(mean, std),
        RandomShortSideScale(min_size=256, max_size=320),
        RandomCrop(resize_to),
    ]
)

# Map output indices back to class names via id2label so the order matches the model's logits.
labels = [trained_model.config.id2label[i] for i in range(trained_model.config.num_labels)]


def parse_video_to_clips(video_file):
    """Re-encode the input video to a fixed resolution/fps and sample random clips from it."""
    new_resolution = (320, 256)
    new_fps = 30
    acceptable_fps_violation = 5
    with tempfile.NamedTemporaryFile() as new_video:
        change_video_resolution_and_fps(video_file, new_video.name, new_resolution, new_fps, acceptable_fps_violation)
        video_path_handler = VideoPathHandler()
        # Load the re-encoded video (not the original upload) so the model sees the normalized resolution/fps.
        video: EncodedVideoPyAV = video_path_handler.video_from_path(new_video.name)

        clip_sampler = make_clip_sampler("random_multi", clip_duration, CLIPS_FROM_SINGLE_VIDEO)
        # noinspection PyTypeChecker
        clip_info: ClipInfoList = clip_sampler(0, video.duration, {})

        # Decode and transform each sampled clip while the temporary file still exists.
        video_clips_list = []
        for clip_start, clip_end in zip(clip_info.clip_start_sec, clip_info.clip_end_sec):
            video_clip = video.get_clip(clip_start, clip_end)["video"]
            video_clips_list.append(inference_transform(video_clip))

    # Stack into a batch of shape (num_clips, num_frames, channels, height, width), as VideoMAE expects.
    videos_tensor = torch.stack([single_clip.permute(1, 0, 2, 3) for single_clip in video_clips_list])
    return videos_tensor.to(DEVICE)


def infer(video_file):
    videos_tensor = parse_video_to_clips(video_file)
    inputs = {"pixel_values": videos_tensor}

    # Forward pass over all sampled clips in a single batch.
    with torch.no_grad():
        outputs = trained_model(**inputs)
        # Sum the per-clip logits into one video-level prediction before the softmax.
        logits = outputs.logits.sum(dim=0)
    softmax_scores = torch.nn.functional.softmax(logits, dim=-1)
    confidences = {labels[i]: float(softmax_scores[i]) for i in range(len(labels))}
    return confidences
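
# Quick local sanity check (runs inference on one of the bundled example videos without the Gradio UI):
#   confidences = infer("examples/DUNK.avi")
#   print(sorted(confidences.items(), key=lambda kv: kv[1], reverse=True)[:3])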


gr.Interface(
    fn=infer,
    inputs=gr.Video(type="file"),
    outputs=gr.Label(num_top_classes=3),
    examples=[
        ["examples/DUNK.avi"],
        ["examples/FLOATING_JUMP_SHOT.avi"],
        ["examples/JUMP_SHOT.avi"],
        ["examples/REVERSE_LAYUP.avi"],
        ["examples/TURNAROUND_HOOK_SHOT.avi"],
    ],
    title="VideoMAE fine-tuned on nba data",
    description=(
        "Gradio demo for VideoMAE for video classification. To use it, simply upload your video or click one of the"
        " examples to load them. Read more at the links below."
    ),
    article=(
        "<div style='text-align: center;'>"
        "<a href='https://huggingface.co/docs/transformers/model_doc/videomae' target='_blank'>VideoMAE</a> | "
        "<a href='https://huggingface.co/omermazig/videomae-finetuned-nba-5-class-4-batch-8000-vid-multiclass' target='_blank'>Fine-tuned Model</a>"
        "</div>"
    ),
    allow_flagging=False,
    allow_screenshot=False,
).launch()