akhaliq committed
Commit 58e78f1
1 Parent(s): 57a3c98

Create app.py

Files changed (1)
  1. app.py +106 -0
app.py ADDED
@@ -0,0 +1,106 @@
+ import json
+ import urllib.request
+
+ import torch
+ import gradio as gr
+ from pytorchvideo.data.encoded_video import EncodedVideo
+ from pytorchvideo.transforms import (
+     ApplyTransformToKey,
+     ShortSideScale,
+     UniformTemporalSubsample,
+ )
+ from torchvision.transforms import Compose, Lambda
+ from torchvision.transforms._transforms_video import (
+     CenterCropVideo,
+     NormalizeVideo,
+ )
+
+ # Load the pretrained `slow_r50` model from PyTorchVideo
+ model = torch.hub.load('facebookresearch/pytorchvideo', 'slow_r50', pretrained=True)
+
+ # Run on CPU (set to "cuda" to use a GPU)
+ device = "cpu"
+ model = model.eval()
+ model = model.to(device)
+
+ # Download the Kinetics 400 class-name mapping
+ json_url = "https://dl.fbaipublicfiles.com/pyslowfast/dataset/class_names/kinetics_classnames.json"
+ json_filename = "kinetics_classnames.json"
+ urllib.request.urlretrieve(json_url, json_filename)
+ with open(json_filename, "r") as f:
+     kinetics_classnames = json.load(f)
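+ # kinetics_classnames maps each label name to its integer class id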
+
+ # Create the inverse id-to-label-name mapping
+ kinetics_id_to_classname = {}
+ for k, v in kinetics_classnames.items():
+     kinetics_id_to_classname[v] = str(k).replace('"', "")
+
+ # Input transform parameters for slow_r50
+ side_size = 256
+ mean = [0.45, 0.45, 0.45]
+ std = [0.225, 0.225, 0.225]
+ crop_size = 256
+ num_frames = 8
+ sampling_rate = 8
+ frames_per_second = 30
+
+ # Note that this transform is specific to the slow_r50 model.
+ transform = ApplyTransformToKey(
+     key="video",
+     transform=Compose(
+         [
+             UniformTemporalSubsample(num_frames),  # pick num_frames frames, evenly spaced
+             Lambda(lambda x: x / 255.0),           # scale uint8 values to [0, 1]
+             NormalizeVideo(mean, std),
+             ShortSideScale(size=side_size),        # resize so the short side is side_size
+             CenterCropVideo(crop_size=(crop_size, crop_size)),
+         ]
+     ),
+ )
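+ # The transform operates on the "video" entry of the clip dict returned by
+ # EncodedVideo.get_clip below, a (C, T, H, W) tensor with values in [0, 255].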
+
+ # The duration of the input clip is also specific to the model:
+ # 8 frames at a sampling stride of 8 from 30 fps video -> 64 / 30 ≈ 2.13 s
+ clip_duration = (num_frames * sampling_rate) / frames_per_second
+
+ # Download an example video for the demo
+ url_link = "https://dl.fbaipublicfiles.com/pytorchvideo/projects/archery.mp4"
+ video_path = 'archery.mp4'
+ urllib.request.urlretrieve(url_link, video_path)
+
+ # Select the duration of the clip to load by specifying the start and end seconds.
+ # start_sec should correspond to where the action occurs in the video.
+ def inference(in_vid):
+     start_sec = 0
+     end_sec = start_sec + clip_duration
+
+     # Initialize an EncodedVideo helper class and load the video
+     video = EncodedVideo.from_path(in_vid)
+
+     # Load the desired clip
+     video_data = video.get_clip(start_sec=start_sec, end_sec=end_sec)
+
+     # Apply a transform to normalize the video input
+     video_data = transform(video_data)
+
+     # Move the inputs to the desired device
+     inputs = video_data["video"]
+     inputs = inputs.to(device)
+
+     # Pass the input clip through the model
+     preds = model(inputs[None, ...])
+
+     # Turn the logits into probabilities and take the top 5 classes
+     post_act = torch.nn.Softmax(dim=1)
+     preds = post_act(preds)
+     pred_classes = preds.topk(k=5).indices[0]
+
+     # Map the predicted classes to the label names
+     pred_class_names = [kinetics_id_to_classname[int(i)] for i in pred_classes]
+     return ", ".join(pred_class_names)
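+ # inference("archery.mp4") would return a comma-separated string of the
+ # top 5 predicted Kinetics 400 labels; Gradio calls it with the uploaded video.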
+
+ inputs = gr.inputs.Video(label="Input Video")
+ outputs = gr.outputs.Textbox(label="Top 5 predicted labels")
+
+ title = "3D ResNet"
+ description = "Demo for 3D ResNet, a ResNet-style video classification network pretrained on the Kinetics 400 dataset. To use it, upload a video or click one of the examples to load it. Read more at the links below."
+ article = "<p style='text-align: center'><a href='https://arxiv.org/abs/1812.03982'>SlowFast Networks for Video Recognition</a> | <a href='https://github.com/facebookresearch/pytorchvideo'>Github Repo</a></p>"
+
+ examples = [['archery.mp4']]
+
+ gr.Interface(inference, inputs, outputs, title=title, description=description, article=article, examples=examples, analytics_enabled=False).launch()