|
import torch |
|
|
|
# Fetch the pretrained "slowfast_r50" entry point from the PyTorchVideo hub repo.
model = torch.hub.load(
    "facebookresearch/pytorchvideo",
    "slowfast_r50",
    pretrained=True,
)
|
import json
import urllib
import urllib.request
from typing import Dict

import gradio as gr
from pytorchvideo.data.encoded_video import EncodedVideo
from pytorchvideo.transforms import (
    ApplyTransformToKey,
    ShortSideScale,
    UniformCropVideo,
    UniformTemporalSubsample,
)
from torchvision.transforms import Compose, Lambda
from torchvision.transforms._transforms_video import (
    CenterCropVideo,
    NormalizeVideo,
)
|
|
|
# All tensors and the model live on this device.
device = "cpu"

# Put the network in evaluation mode and move it onto the target device.
model = model.eval().to(device)
|
|
|
json_filename = "./kinetics_classnames.json"

with open(json_filename, "r") as handle:
    kinetics_classnames = json.load(handle)

# Invert the loaded mapping so a predicted id can be looked up directly;
# any double-quote characters embedded in the names are stripped.
kinetics_id_to_classname = {
    class_id: str(class_name).replace('"', "")
    for class_name, class_id in kinetics_classnames.items()
}
|
# Spatial preprocessing: short-side resize target and centre-crop size.
side_size = 256
crop_size = 256

# Per-channel normalisation statistics handed to NormalizeVideo.
mean = [0.45] * 3
std = [0.225] * 3

# Temporal sampling; together these determine clip_duration.
num_frames = 32
sampling_rate = 2
frames_per_second = 30

# Ratio between the fast and slow pathway frame counts in PackPathway.
slowfast_alpha = 4

# NOTE(review): not referenced anywhere else in the visible part of this file.
num_clips = 10
num_crops = 3
|
|
|
class PackPathway(torch.nn.Module):
    """Split a clip tensor into the ``[slow, fast]`` pair of pathway tensors.

    The fast pathway is the input unchanged. The slow pathway keeps
    ``frames.shape[1] // alpha`` evenly spaced entries along dimension 1
    (presumably the time axis of a (C, T, H, W) clip -- confirm with caller).

    Args:
        alpha: Ratio between the fast and slow pathway lengths. When left as
            ``None`` the module-level ``slowfast_alpha`` is read on every
            call, which is the behaviour of a plain ``PackPathway()``.
    """

    def __init__(self, alpha=None):
        super().__init__()
        self.alpha = alpha

    def forward(self, frames: torch.Tensor):
        """Return ``[slow_pathway, fast_pathway]`` for ``frames``."""
        alpha = slowfast_alpha if self.alpha is None else self.alpha

        fast_pathway = frames

        # Evenly spaced positions from the first to the last entry of dim 1.
        # The indices are computed as before and then moved next to the data:
        # index_select needs its index tensor on the same device as the input.
        slow_indices = (
            torch.linspace(0, frames.shape[1] - 1, frames.shape[1] // alpha)
            .long()
            .to(frames.device)
        )
        slow_pathway = torch.index_select(frames, 1, slow_indices)

        return [slow_pathway, fast_pathway]
|
|
|
# Preprocessing steps, applied in this order to the "video" entry of a clip.
_video_steps = [
    UniformTemporalSubsample(num_frames),
    # Divide by 255 (presumably maps 8-bit pixel values into [0, 1]).
    Lambda(lambda x: x / 255.0),
    NormalizeVideo(mean, std),
    ShortSideScale(size=side_size),
    CenterCropVideo(crop_size),
    # Last step: turn the single tensor into the [slow, fast] list.
    PackPathway(),
]

transform = ApplyTransformToKey(key="video", transform=Compose(_video_steps))
|
|
|
|
|
# Length, in seconds, of the clip that inference() reads from each video.
clip_duration = (num_frames * sampling_rate) / frames_per_second

# Download the sample video referenced by the UI's examples list.
url_link = "https://dl.fbaipublicfiles.com/pytorchvideo/projects/archery.mp4"
video_path = "archery.mp4"

# urllib.URLopener does not exist on Python 3, so the old try-branch could
# never succeed and its bare except hid every error, including Ctrl-C.
# Call the Python 3 API directly and let real download errors surface.
urllib.request.urlretrieve(url_link, video_path)
|
|
|
|
|
|
|
def inference(in_vid):
    """Classify the opening clip of a video and return its top-5 labels.

    Args:
        in_vid: Location of the video; handed unchanged to
            ``EncodedVideo.from_path``.

    Returns:
        A single string with the five highest-scoring class names, in the
        order ``topk`` returns them, separated by ", ".
    """
    # Only the first clip_duration seconds of the video are scored.
    start_sec = 0
    end_sec = start_sec + clip_duration

    video = EncodedVideo.from_path(in_vid)
    video_data = video.get_clip(start_sec=start_sec, end_sec=end_sec)

    # Run the module-level preprocessing pipeline on the clip dict.
    video_data = transform(video_data)

    # Each pathway tensor is moved to the device and given a leading axis.
    inputs = [
        pathway.to(device)[None, ...] for pathway in video_data["video"]
    ]

    # Gradients are never used here; disabling autograd avoids building a
    # graph and keeps memory use down during the forward pass.
    with torch.no_grad():
        preds = model(inputs)
        preds = torch.nn.Softmax(dim=1)(preds)

    pred_classes = preds.topk(k=5).indices[0]

    pred_class_names = [
        kinetics_id_to_classname[int(class_id)] for class_id in pred_classes
    ]
    return ", ".join(pred_class_names)
|
|
|
# UI components: one video in, one text box out.
inputs = gr.inputs.Video(label="Input Video")
outputs = gr.outputs.Textbox(label="Top 5 predicted labels")

# Text shown around the demo.
title = "SLOWFAST"
description = (
    "demo for SLOWFAST, SlowFast networks pretrained on the Kinetics 400 dataset. "
    "To use it, simply upload your video, or click one of the examples to load them. "
    "Read more at the links below."
)
article = (
    "<p style='text-align: center'>"
    "<a href='https://arxiv.org/abs/1812.03982'>SlowFast Networks for Video Recognition</a>"
    " | "
    "<a href='https://github.com/facebookresearch/pytorchvideo'>Github Repo</a>"
    "</p>"
)

# One example row per clickable sample; each row holds the video path.
examples = [["archery.mp4"]]

demo = gr.Interface(
    fn=inference,
    inputs=inputs,
    outputs=outputs,
    title=title,
    description=description,
    article=article,
    examples=examples,
    analytics_enabled=False,
)
demo.launch(debug=True)