import os
import gradio as gr
from utils import (
    create_gif_from_video_file,
    download_youtube_video,
    get_num_total_frames,
)
from transformers import pipeline
from huggingface_hub import HfApi, ModelSearchArguments, ModelFilter

# Sample every 4th frame of the input video (the pipeline default is every frame)
FRAME_SAMPLING_RATE = 4
DEFAULT_MODEL = "facebook/timesformer-base-finetuned-k400"

VALID_VIDEOCLASSIFICATION_MODELS = [
    "MCG-NJU/videomae-large-finetuned-kinetics",
    "facebook/timesformer-base-finetuned-k400",
    "fcakyon/timesformer-large-finetuned-k400",
    "MCG-NJU/videomae-base-finetuned-kinetics",
    "facebook/timesformer-base-finetuned-k600",
    "fcakyon/timesformer-large-finetuned-k600",
    "facebook/timesformer-hr-finetuned-k400",
    "facebook/timesformer-hr-finetuned-k600",
    "facebook/timesformer-base-finetuned-ssv2",
    "fcakyon/timesformer-large-finetuned-ssv2",
    "facebook/timesformer-hr-finetuned-ssv2",
    "MCG-NJU/videomae-base-finetuned-ssv2",
    "MCG-NJU/videomae-base-short-finetuned-kinetics",
    "MCG-NJU/videomae-base-short-ssv2",
    "MCG-NJU/videomae-base-short-finetuned-ssv2",
    "sayakpaul/videomae-base-finetuned-ucf101-subset",
    "nateraw/videomae-base-finetuned-ucf101",
    "MCG-NJU/videomae-base-ssv2",
    "zahrav/videomae-base-finetuned-ucf101-subset",
]


pipe = pipeline(
    task="video-classification",
    model=DEFAULT_MODEL,
    top_k=5,
    frame_sampling_rate=FRAME_SAMPLING_RATE,
)
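
# The pipeline downloads and caches the default checkpoint on first use;
# top_k=5 matches the five classes shown in the gr.Label output below.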


examples = [
    ["https://www.youtube.com/watch?v=huAJ9dC5lmI"],
    ["https://www.youtube.com/watch?v=wvcWt6u5HTg"],
    ["https://www.youtube.com/watch?v=-3kZSi5qjRM"],
    ["https://www.youtube.com/watch?v=-6usjfP8hys"],
    ["https://www.youtube.com/watch?v=BDHub0gBGtc"],
    ["https://www.youtube.com/watch?v=B9ea7YyCP6E"],
    ["https://www.youtube.com/watch?v=BBkpaeJBKmk"],
    ["https://www.youtube.com/watch?v=BBqU8Apee_g"],
    ["https://www.youtube.com/watch?v=B8OdMwVwyXc"],
    ["https://www.youtube.com/watch?v=I7cwq6_4QtM"],
    ["https://www.youtube.com/watch?v=Z0mJDXpNhYA"],
    ["https://www.youtube.com/watch?v=QkQQjFGnZlg"],
    ["https://www.youtube.com/watch?v=IQaoRUQif14"],
]


def get_video_model_names():
    """Fetch all Transformers video-classification model ids from the Hub,
    sorted by downloads in descending order."""
    model_args = ModelSearchArguments()
    model_filter = ModelFilter(
        task=model_args.pipeline_tag.VideoClassification,
        library=model_args.library.Transformers,
    )
    api = HfApi()
    video_models = api.list_models(filter=model_filter, sort="downloads", direction=-1)
    return [video_model.id for video_model in video_models]
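
# NOTE: get_video_model_names() is defined but never called by this app; the
# curated VALID_VIDEOCLASSIFICATION_MODELS list above feeds the dropdown.
# ModelSearchArguments and ModelFilter are deprecated in recent huggingface_hub
# releases; a rough, untested sketch of an equivalent query there would be:
#
#     api = HfApi()
#     video_models = [
#         m.id
#         for m in api.list_models(
#             filter="video-classification", sort="downloads", direction=-1
#         )
#     ]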


def select_model(model_name):
    """Rebuild the global classification pipeline for the chosen model."""
    global pipe
    pipe = pipeline(
        task="video-classification",
        model=model_name,
        top_k=5,
        frame_sampling_rate=FRAME_SAMPLING_RATE,
    )
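    # NOTE: mutating the module-level `pipe` means the selected model is shared
    # by every visitor of the app; that is fine for a single-user demo, but a
    # multi-user deployment would likely want per-session state instead.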


def predict(youtube_url_or_file_path):
    """Classify a video from a YouTube URL or a local file path; returns a
    {label: score} dict and the path of a preview GIF."""
    if youtube_url_or_file_path.startswith("http"):
        video_path = download_youtube_video(youtube_url_or_file_path)
    else:
        video_path = youtube_url_or_file_path

    # lower the sampling rate when the clip is too short for the model's input length
    num_total_frames = get_num_total_frames(video_path)
    num_model_input_frames = pipe.model.config.num_frames
    if num_total_frames < FRAME_SAMPLING_RATE * num_model_input_frames:
        # guard against a sampling rate of 0 for clips shorter than the model input
        frame_sampling_rate = max(num_total_frames // num_model_input_frames, 1)
    else:
        frame_sampling_rate = FRAME_SAMPLING_RATE
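    # e.g. a 16-frame model at rate 4 needs 64 frames, so a 60-frame clip
    # would instead be sampled at 60 // 16 = 3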

    gif_path = create_gif_from_video_file(
        video_path, frame_sampling_rate=frame_sampling_rate, save_path="video.gif"
    )

    # run inference
    results = pipe(videos=video_path, frame_sampling_rate=frame_sampling_rate)

    # delete the downloaded/uploaded source video; only the GIF is kept for display
    os.remove(video_path)

    label_to_score = {result["label"]: result["score"] for result in results}

    return label_to_score, gif_path
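
# predict() can also be called directly, e.g. with the first example URL below:
#     label_to_score, gif_path = predict("https://www.youtube.com/watch?v=huAJ9dC5lmI")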


# Create the Gradio app
app = gr.Blocks()

with app:
    # Title and description
    gr.Markdown("# **<p align='center'>Video Classification with 🤗 Transformers</p>**")
    gr.Markdown(
        """
        <p style='text-align: center'>
        Perform video classification with <a href='https://huggingface.co/models?pipeline_tag=video-classification&library=transformers' target='_blank'>Hugging Face Transformers video models</a>.
        <br> For zero-shot classification, you can use the <a href='https://huggingface.co/spaces/fcakyon/zero-shot-video-classification' target='_blank'>zero-shot classification demo</a>.
        </p>
        """
    )

    # Model Selection and Input
    gr.Label("Model:")
    model_names_dropdown = gr.Dropdown(
        choices=VALID_VIDEOCLASSIFICATION_MODELS,
        value=DEFAULT_MODEL,
    )
    model_names_dropdown.change(fn=select_model, inputs=model_names_dropdown)

    # Tabs for YouTube URL and local file
    with gr.Row():
        with gr.Column():
            gr.Markdown("### **Provide a Youtube video URL**")
            youtube_url = gr.Textbox(label="Youtube URL:")
            youtube_url_predict_btn = gr.Button(value="Predict")
        with gr.Column():
            gr.Markdown("### **Upload a video file**")
            video_file = gr.Video(label="Video File:")
            local_video_predict_btn = gr.Button(value="Predict")

    # Display Input Clip
    video_gif = gr.Image(label="Input Clip")

    # Display Predictions
    predictions = gr.Label(label="Predictions:", num_top_classes=5)

    # Examples and Prediction Buttons
    gr.Markdown("**Examples:**")
    gr.Examples(
        examples,
        youtube_url,
        [predictions, video_gif],
        fn=predict,
        cache_examples=True,
    )

    # Click handlers for prediction buttons
    youtube_url_predict_btn.click(
        predict, inputs=youtube_url, outputs=[predictions, video_gif]
    )
    local_video_predict_btn.click(
        predict, inputs=video_file, outputs=[predictions, video_gif]
    )

# Launch the Gradio app
app.launch()