VideoMAE / app.py
akhaliq's picture
akhaliq HF staff
Create new file
8323f05
raw history blame
No virus
1.77 kB
from decord import VideoReader, cpu
import torch
import numpy as np
from transformers import VideoMAEFeatureExtractor, VideoMAEForVideoClassification
from huggingface_hub import hf_hub_download
import gradio as gr
# Seed NumPy's global RNG so the random clip position drawn by
# sample_frame_indices (via np.random.randint) is reproducible across restarts.
np.random.seed(0)
def sample_frame_indices(clip_len, frame_sample_rate, seg_len):
    """Choose ``clip_len`` frame indices from a video of ``seg_len`` frames.

    A window of ``clip_len * frame_sample_rate`` consecutive frames is placed
    at a random position (using NumPy's global RNG) and ``clip_len`` evenly
    spaced indices are taken from it. If the video is not longer than that
    window, the indices are spread evenly over the whole video instead, so
    short clips still yield ``clip_len`` valid (possibly repeated) indices.

    Args:
        clip_len: Number of frame indices to return.
        frame_sample_rate: Approximate stride, in frames, between samples.
        seg_len: Total number of frames available in the video.

    Returns:
        A 1-D ``np.int64`` array of ``clip_len`` non-decreasing indices, each
        in ``[0, seg_len - 1]``.

    Raises:
        ValueError: If ``seg_len`` is not positive (there is nothing to sample).
    """
    if seg_len <= 0:
        raise ValueError(f"Cannot sample frames from a video with {seg_len} frames")

    converted_len = int(clip_len * frame_sample_rate)
    if converted_len < seg_len:
        # Enough frames: pick a random window of `converted_len` frames.
        end_idx = np.random.randint(converted_len, seg_len)
        start_idx = end_idx - converted_len
    else:
        # Video shorter than the window: np.random.randint would raise
        # (low >= high), so fall back to covering the whole video.
        start_idx, end_idx = 0, seg_len

    indices = np.linspace(start_idx, end_idx, num=clip_len)
    # `end_idx` is exclusive, so clamp the last sample back into range.
    return np.clip(indices, start_idx, end_idx - 1).astype(np.int64)
_VIDEOMAE_CHECKPOINT = "MCG-NJU/videomae-base-finetuned-kinetics"
_videomae_cache = {}


def _load_videomae():
    """Return ``(feature_extractor, model)``, loading the checkpoint on first use only."""
    if not _videomae_cache:
        _videomae_cache["feature_extractor"] = VideoMAEFeatureExtractor.from_pretrained(
            _VIDEOMAE_CHECKPOINT
        )
        _videomae_cache["model"] = VideoMAEForVideoClassification.from_pretrained(
            _VIDEOMAE_CHECKPOINT
        )
    return _videomae_cache["feature_extractor"], _videomae_cache["model"]


def inference(file_path):
    """Classify the video at ``file_path`` and return the predicted label.

    Sixteen frames are sampled from the video with ``sample_frame_indices``,
    preprocessed by the VideoMAE feature extractor, and passed through the
    classification model. The label with the highest logit is returned.

    Args:
        file_path: Path to the video file to classify.

    Returns:
        The predicted class name, looked up in ``model.config.id2label``.

    Raises:
        ValueError: If ``file_path`` is empty or ``None`` (no video supplied).
    """
    if not file_path:
        raise ValueError("No video was provided.")

    # The uploaded video may have any length; len(videoreader) gives its frame count.
    videoreader = VideoReader(file_path, num_threads=1, ctx=cpu(0))

    # sample 16 frames
    videoreader.seek(0)
    indices = sample_frame_indices(
        clip_len=16, frame_sample_rate=4, seg_len=len(videoreader)
    )
    video = videoreader.get_batch(indices).asnumpy()

    # Reuse the cached extractor/model instead of reloading them on every request.
    feature_extractor, model = _load_videomae()
    inputs = feature_extractor(list(video), return_tensors="pt")

    with torch.no_grad():
        outputs = model(**inputs)
        logits = outputs.logits

    # model predicts one of the 400 Kinetics-400 classes
    predicted_label = logits.argmax(-1).item()
    return model.config.id2label[predicted_label]
# Gradio UI: a video input and a "Run" button on the left, the predicted
# label on the right.
with gr.Blocks() as demo:
    with gr.Row():
        with gr.Column():
            video = gr.Video()
            btn = gr.Button(value="Run")
        with gr.Column():
            label = gr.Textbox(label="Predicted Label")

    # The button is bound to `btn`; the previous `translate_btn` name was
    # undefined and raised NameError at startup.
    btn.click(inference, inputs=video, outputs=label)

demo.launch()