akhaliq committed
Commit 2423d1f
1 Parent(s): 6e0e961

Create app.py

Files changed (1): app.py (+131, -0)
app.py ADDED

import torch
import json
import urllib.request

from torchvision.transforms import Compose, Lambda
from torchvision.transforms._transforms_video import (
    CenterCropVideo,
    NormalizeVideo,
)
from pytorchvideo.data.encoded_video import EncodedVideo
from pytorchvideo.transforms import (
    ApplyTransformToKey,
    ShortSideScale,
    UniformTemporalSubsample,
)

import gradio as gr

# Choose the `slowfast_r50` model, pretrained on Kinetics 400
model = torch.hub.load('facebookresearch/pytorchvideo', 'slowfast_r50', pretrained=True)

# Set to GPU or CPU
device = "cpu"
model = model.eval()
model = model.to(device)
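
# Note: as used below, the hub model's forward takes a list of two pathway
# tensors, [slow_pathway, fast_pathway], each shaped
# (batch, channel, time, height, width), and returns logits over the
# 400 Kinetics classes.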

# Download the Kinetics 400 class names (a JSON mapping of label name -> id)
json_url = "https://dl.fbaipublicfiles.com/pyslowfast/dataset/class_names/kinetics_classnames.json"
json_filename = "kinetics_classnames.json"
urllib.request.urlretrieve(json_url, json_filename)
with open(json_filename, "r") as f:
    kinetics_classnames = json.load(f)

# Create an id to label name mapping
kinetics_id_to_classname = {}
for k, v in kinetics_classnames.items():
    kinetics_id_to_classname[v] = str(k).replace('"', "")
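
# For reference, the JSON maps label names (some wrapped in stray quote
# characters, hence the replace above) to integer class ids; inverting it
# lets a predicted id be mapped back to a readable label, e.g. id 0 ->
# "abseiling" under the usual alphabetical Kinetics 400 ordering.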

# Input transform parameters specific to the slowfast_r50 model
side_size = 256
mean = [0.45, 0.45, 0.45]
std = [0.225, 0.225, 0.225]
crop_size = 256
num_frames = 32
sampling_rate = 2
frames_per_second = 30
slowfast_alpha = 4
# num_clips and num_crops come from the multi-clip evaluation recipe and are
# unused in this single-clip demo.
num_clips = 10
num_crops = 3

class PackPathway(torch.nn.Module):
    """
    Transform that packs a video tensor into a list of the SlowFast model's
    two pathway tensors: a temporally subsampled slow pathway and the
    full frame-rate fast pathway.
    """
    def __init__(self):
        super().__init__()

    def forward(self, frames: torch.Tensor):
        fast_pathway = frames
        # Temporally subsample the frames to build the slow pathway.
        slow_pathway = torch.index_select(
            frames,
            1,
            torch.linspace(
                0, frames.shape[1] - 1, frames.shape[1] // slowfast_alpha
            ).long(),
        )
        frame_list = [slow_pathway, fast_pathway]
        return frame_list
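
# A rough shape sketch (assuming the defaults num_frames = 32 and
# slowfast_alpha = 4): for a (C, T, H, W) clip of shape (3, 32, 256, 256),
# PackPathway returns [slow_pathway, fast_pathway] with shapes
# (3, 8, 256, 256) and (3, 32, 256, 256) respectively.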

transform = ApplyTransformToKey(
    key="video",
    transform=Compose(
        [
            UniformTemporalSubsample(num_frames),
            Lambda(lambda x: x / 255.0),
            NormalizeVideo(mean, std),
            ShortSideScale(size=side_size),
            CenterCropVideo(crop_size),
            PackPathway(),
        ]
    ),
)
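
# The Compose runs top to bottom on the decoded (C, T, H, W) uint8 clip:
# subsample to num_frames frames, rescale pixels to [0, 1], normalize per
# channel, resize the short side to side_size, center-crop to crop_size,
# then split into the two SlowFast pathways.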

# The duration of the input clip is also specific to the model.
clip_duration = (num_frames * sampling_rate) / frames_per_second
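# With the defaults above that is (32 * 2) / 30, i.e. roughly 2.13 seconds
# of video (64 source frames at 30 fps) per clip.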

# Download an example video used in the Gradio examples list
url_link = "https://dl.fbaipublicfiles.com/pytorchvideo/projects/archery.mp4"
video_path = 'archery.mp4'
urllib.request.urlretrieve(url_link, video_path)

def inference(in_vid):
    # Select the clip to load by specifying the start and end seconds;
    # start_sec should correspond to where the action occurs in the video.
    start_sec = 0
    end_sec = start_sec + clip_duration

    # Initialize an EncodedVideo helper class and load the video
    video = EncodedVideo.from_path(in_vid)

    # Load the desired clip
    video_data = video.get_clip(start_sec=start_sec, end_sec=end_sec)

    # Apply a transform to normalize the video input
    video_data = transform(video_data)

    # Move the inputs to the desired device
    inputs = video_data["video"]
    inputs = [i.to(device)[None, ...] for i in inputs]

    # Pass the input clip through the model
    with torch.no_grad():
        preds = model(inputs)

    # Turn the logits into probabilities and take the top 5 class ids
    post_act = torch.nn.Softmax(dim=1)
    preds = post_act(preds)
    pred_classes = preds.topk(k=5).indices[0]

    # Map the predicted class ids to the label names
    pred_class_names = [kinetics_id_to_classname[int(i)] for i in pred_classes]
    return ", ".join(pred_class_names)
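
# Quick local sanity check (assuming the archery.mp4 download above
# succeeded): print(inference("archery.mp4")) prints the model's top-5
# Kinetics labels for the clip as a comma-separated string.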

# Build the Gradio UI (gr.inputs/gr.outputs is the legacy, pre-3.x Gradio API)
inputs = gr.inputs.Video(label="Input Video")
outputs = gr.outputs.Textbox(label="Top 5 predicted labels")

title = "SlowFast"
description = "Demo for SlowFast networks pretrained on the Kinetics 400 dataset. To use it, simply upload a video or click one of the examples to load it. Read more at the links below."
article = "<p style='text-align: center'><a href='https://arxiv.org/abs/1812.03982'>SlowFast Networks for Video Recognition</a> | <a href='https://github.com/facebookresearch/pytorchvideo'>Github Repo</a></p>"

examples = [
    ['archery.mp4']
]

gr.Interface(inference, inputs, outputs, title=title, description=description,
             article=article, examples=examples, analytics_enabled=False).launch(debug=True)