rickysk commited on
Commit
dccaadc
1 Parent(s): 5b71f2a

Upload 5 files

Browse files

Adding edited files

app.py ADDED
@@ -0,0 +1,140 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import cv2
2
+ import gradio as gr
3
+ import imutils
4
+ import numpy as np
5
+ import torch
6
+ from pytorchvideo.transforms import (
7
+ ApplyTransformToKey,
8
+ Normalize,
9
+ RandomShortSideScale,
10
+ RemoveKey,
11
+ ShortSideScale,
12
+ UniformTemporalSubsample,
13
+ )
14
+ from torchvision.transforms import (
15
+ Compose,
16
+ Lambda,
17
+ RandomCrop,
18
+ RandomHorizontalFlip,
19
+ Resize,
20
+ )
21
+ from transformers import VideoMAEFeatureExtractor, VideoMAEForVideoClassification
22
+
23
+ MODEL_CKPT = "rickysk/rickysk-videomae-base-ipm_all_videos"
24
+ DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
25
+
26
+ MODEL = VideoMAEForVideoClassification.from_pretrained(MODEL_CKPT).to(DEVICE)
27
+ PROCESSOR = VideoMAEFeatureExtractor.from_pretrained(MODEL_CKPT)
28
+
29
+ RESIZE_TO = PROCESSOR.size["shortest_edge"]
30
+ NUM_FRAMES_TO_SAMPLE = MODEL.config.num_frames
31
+ IMAGE_STATS = {"image_mean": [0.485, 0.456, 0.406], "image_std": [0.229, 0.224, 0.225]}
32
+ VAL_TRANSFORMS = Compose(
33
+ [
34
+ UniformTemporalSubsample(NUM_FRAMES_TO_SAMPLE),
35
+ Lambda(lambda x: x / 255.0),
36
+ Normalize(IMAGE_STATS["image_mean"], IMAGE_STATS["image_std"]),
37
+ Resize((RESIZE_TO, RESIZE_TO)),
38
+ ]
39
+ )
40
+ LABELS = list(MODEL.config.label2id.keys())
41
+
42
+
43
+ def parse_video(video_file):
44
+ """A utility to parse the input videos.
45
+
46
+ Reference: https://pyimagesearch.com/2018/11/12/yolo-object-detection-with-opencv/
47
+ """
48
+ vs = cv2.VideoCapture(video_file)
49
+
50
+ # try to determine the total number of frames in the video file
51
+ try:
52
+ prop = (
53
+ cv2.cv.CV_CAP_PROP_FRAME_COUNT
54
+ if imutils.is_cv2()
55
+ else cv2.CAP_PROP_FRAME_COUNT
56
+ )
57
+ total = int(vs.get(prop))
58
+ print("[INFO] {} total frames in video".format(total))
59
+
60
+ # an error occurred while trying to determine the total
61
+ # number of frames in the video file
62
+ except:
63
+ print("[INFO] could not determine # of frames in video")
64
+ print("[INFO] no approx. completion time can be provided")
65
+ total = -1
66
+
67
+ frames = []
68
+
69
+ # loop over frames from the video file stream
70
+ while True:
71
+ # read the next frame from the file
72
+ (grabbed, frame) = vs.read()
73
+ if frame is not None:
74
+ frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
75
+ frames.append(frame)
76
+ # if the frame was not grabbed, then we have reached the end
77
+ # of the stream
78
+ if not grabbed:
79
+ break
80
+
81
+ return frames
82
+
83
+
84
+ def preprocess_video(frames: list):
85
+ """Utility to apply preprocessing transformations to a video tensor."""
86
+ # Each frame in the `frames` list has the shape: (height, width, num_channels).
87
+ # Collated together the `frames` has the the shape: (num_frames, height, width, num_channels).
88
+ # So, after converting the `frames` list to a torch tensor, we permute the shape
89
+ # such that it becomes (num_channels, num_frames, height, width) to make
90
+ # the shape compatible with the preprocessing transformations. After applying the
91
+ # preprocessing chain, we permute the shape to (num_frames, num_channels, height, width)
92
+ # to make it compatible with the model. Finally, we add a batch dimension so that our video
93
+ # classification model can operate on it.
94
+ video_tensor = torch.tensor(np.array(frames).astype(frames[0].dtype))
95
+ video_tensor = video_tensor.permute(
96
+ 3, 0, 1, 2
97
+ ) # (num_channels, num_frames, height, width)
98
+ video_tensor_pp = VAL_TRANSFORMS(video_tensor)
99
+ video_tensor_pp = video_tensor_pp.permute(
100
+ 1, 0, 2, 3
101
+ ) # (num_frames, num_channels, height, width)
102
+ video_tensor_pp = video_tensor_pp.unsqueeze(0)
103
+ return video_tensor_pp.to(DEVICE)
104
+
105
+
106
+ def infer(video_file):
107
+ frames = parse_video(video_file)
108
+ video_tensor = preprocess_video(frames)
109
+ inputs = {"pixel_values": video_tensor}
110
+
111
+ # forward pass
112
+ with torch.no_grad():
113
+ outputs = MODEL(**inputs)
114
+ logits = outputs.logits
115
+ softmax_scores = torch.nn.functional.softmax(logits, dim=-1).squeeze(0)
116
+ confidences = {LABELS[i]: float(softmax_scores[i]) for i in range(len(LABELS))}
117
+ return confidences
118
+
119
+
120
+ gr.Interface(
121
+ fn=infer,
122
+ inputs=gr.Video(type="file"),
123
+ outputs=gr.Label(num_top_classes=3),
124
+ examples=[
125
+ ["examples/babycrawling.mp4"],
126
+ ["examples/baseball.mp4"],
127
+ ["examples/balancebeam.mp4"],
128
+ ],
129
+ title="VideoMAE IPM",
130
+ description=(
131
+ "Gradio demo for VideoMAE for video classification. To use it, simply upload your video or click one of the"
132
+ " examples to load them. Read more at the links below."
133
+ ),
134
+ article=(
135
+ "<div style='text-align: center;'><a href='https://huggingface.co/docs/transformers/model_doc/videomae' target='_blank'>VideoMAE</a>"
136
+ " <center><a href='https://huggingface.co/rickysk/rickysk-videomae-base-ipm_all_videos' target='_blank'>Fine-tuned Model</a></center></div>"
137
+ ),
138
+ allow_flagging=False,
139
+ allow_screenshot=False,
140
+ ).launch()
examples/babycrawling.mp4 ADDED
Binary file (309 kB). View file
 
examples/balancebeam.mp4 ADDED
Binary file (400 kB). View file
 
examples/baseball.mp4 ADDED
Binary file (199 kB). View file
 
requirements.txt ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ opencv-python
2
+ imutils
3
+ numpy
4
+ torch
5
+ torchvision
6
+ pytorchvideo
7
+ transformers