Create app.py
Browse files
app.py
ADDED
@@ -0,0 +1,131 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
# Choose the `slowfast_r50` model
|
3 |
+
model = torch.hub.load('facebookresearch/pytorchvideo', 'slowfast_r50', pretrained=True)
|
4 |
+
from typing import Dict
|
5 |
+
import json
|
6 |
+
import urllib
|
7 |
+
from torchvision.transforms import Compose, Lambda
|
8 |
+
from torchvision.transforms._transforms_video import (
|
9 |
+
CenterCropVideo,
|
10 |
+
NormalizeVideo,
|
11 |
+
)
|
12 |
+
from pytorchvideo.data.encoded_video import EncodedVideo
|
13 |
+
from pytorchvideo.transforms import (
|
14 |
+
ApplyTransformToKey,
|
15 |
+
ShortSideScale,
|
16 |
+
UniformTemporalSubsample,
|
17 |
+
UniformCropVideo
|
18 |
+
)
|
19 |
+
|
20 |
+
import gradio as gr
|
21 |
+
# Set to GPU or CPU
|
22 |
+
device = "cpu"
|
23 |
+
model = model.eval()
|
24 |
+
model = model.to(device)
|
25 |
+
json_url = "https://dl.fbaipublicfiles.com/pyslowfast/dataset/class_names/kinetics_classnames.json"
|
26 |
+
json_filename = "kinetics_classnames.json"
|
27 |
+
try: urllib.URLopener().retrieve(json_url, json_filename)
|
28 |
+
except: urllib.request.urlretrieve(json_url, json_filename)
|
29 |
+
with open(json_filename, "r") as f:
|
30 |
+
kinetics_classnames = json.load(f)
|
31 |
+
|
32 |
+
# Create an id to label name mapping
|
33 |
+
kinetics_id_to_classname = {}
|
34 |
+
for k, v in kinetics_classnames.items():
|
35 |
+
kinetics_id_to_classname[v] = str(k).replace('"', "")
|
36 |
+
side_size = 256
|
37 |
+
mean = [0.45, 0.45, 0.45]
|
38 |
+
std = [0.225, 0.225, 0.225]
|
39 |
+
crop_size = 256
|
40 |
+
num_frames = 32
|
41 |
+
sampling_rate = 2
|
42 |
+
frames_per_second = 30
|
43 |
+
slowfast_alpha = 4
|
44 |
+
num_clips = 10
|
45 |
+
num_crops = 3
|
46 |
+
|
47 |
+
class PackPathway(torch.nn.Module):
|
48 |
+
"""
|
49 |
+
Transform for converting video frames as a list of tensors.
|
50 |
+
"""
|
51 |
+
def __init__(self):
|
52 |
+
super().__init__()
|
53 |
+
|
54 |
+
def forward(self, frames: torch.Tensor):
|
55 |
+
fast_pathway = frames
|
56 |
+
# Perform temporal sampling from the fast pathway.
|
57 |
+
slow_pathway = torch.index_select(
|
58 |
+
frames,
|
59 |
+
1,
|
60 |
+
torch.linspace(
|
61 |
+
0, frames.shape[1] - 1, frames.shape[1] // slowfast_alpha
|
62 |
+
).long(),
|
63 |
+
)
|
64 |
+
frame_list = [slow_pathway, fast_pathway]
|
65 |
+
return frame_list
|
66 |
+
|
67 |
+
transform = ApplyTransformToKey(
|
68 |
+
key="video",
|
69 |
+
transform=Compose(
|
70 |
+
[
|
71 |
+
UniformTemporalSubsample(num_frames),
|
72 |
+
Lambda(lambda x: x/255.0),
|
73 |
+
NormalizeVideo(mean, std),
|
74 |
+
ShortSideScale(
|
75 |
+
size=side_size
|
76 |
+
),
|
77 |
+
CenterCropVideo(crop_size),
|
78 |
+
PackPathway()
|
79 |
+
]
|
80 |
+
),
|
81 |
+
)
|
82 |
+
|
83 |
+
# The duration of the input clip is also specific to the model.
|
84 |
+
clip_duration = (num_frames * sampling_rate)/frames_per_second
|
85 |
+
url_link = "https://dl.fbaipublicfiles.com/pytorchvideo/projects/archery.mp4"
|
86 |
+
video_path = 'archery.mp4'
|
87 |
+
try: urllib.URLopener().retrieve(url_link, video_path)
|
88 |
+
except: urllib.request.urlretrieve(url_link, video_path)
|
89 |
+
# Select the duration of the clip to load by specifying the start and end duration
|
90 |
+
# The start_sec should correspond to where the action occurs in the video
|
91 |
+
|
92 |
+
def inference(in_vid):
|
93 |
+
start_sec = 0
|
94 |
+
end_sec = start_sec + clip_duration
|
95 |
+
|
96 |
+
# Initialize an EncodedVideo helper class and load the video
|
97 |
+
video = EncodedVideo.from_path(in_vid)
|
98 |
+
|
99 |
+
# Load the desired clip
|
100 |
+
video_data = video.get_clip(start_sec=start_sec, end_sec=end_sec)
|
101 |
+
|
102 |
+
# Apply a transform to normalize the video input
|
103 |
+
video_data = transform(video_data)
|
104 |
+
|
105 |
+
# Move the inputs to the desired device
|
106 |
+
inputs = video_data["video"]
|
107 |
+
inputs = [i.to(device)[None, ...] for i in inputs]
|
108 |
+
# Pass the input clip through the model
|
109 |
+
preds = model(inputs)
|
110 |
+
|
111 |
+
# Get the predicted classes
|
112 |
+
post_act = torch.nn.Softmax(dim=1)
|
113 |
+
preds = post_act(preds)
|
114 |
+
pred_classes = preds.topk(k=5).indices[0]
|
115 |
+
|
116 |
+
# Map the predicted classes to the label names
|
117 |
+
pred_class_names = [kinetics_id_to_classname[int(i)] for i in pred_classes]
|
118 |
+
return "%s" % ", ".join(pred_class_names)
|
119 |
+
|
120 |
+
inputs = gr.inputs.Video(label="Input Video")
|
121 |
+
outputs = gr.outputs.Textbox(label="Top 5 predicted labels")
|
122 |
+
|
123 |
+
title = "SLOWFAST"
|
124 |
+
description = "demo for SLOWFAST, SlowFast networks pretrained on the Kinetics 400 dataset. To use it, simply upload your video, or click one of the examples to load them. Read more at the links below."
|
125 |
+
article = "<p style='text-align: center'><a href='https://arxiv.org/abs/1812.03982'>SlowFast Networks for Video Recognition</a> | <a href='https://github.com/facebookresearch/pytorchvideo'>Github Repo</a></p>"
|
126 |
+
|
127 |
+
examples = [
|
128 |
+
['archery.mp4']
|
129 |
+
]
|
130 |
+
|
131 |
+
gr.Interface(inference, inputs, outputs, title=title, description=description, article=article, examples=examples, analytics_enabled=False).launch(debug=True)
|