import cv2
import imutils
import gradio as gr
import numpy as np
import torch
from pytorchvideo.transforms import (
    ApplyTransformToKey,
    Normalize,
    RandomShortSideScale,
    RemoveKey,
    ShortSideScale,
    UniformTemporalSubsample,
)
from torchvision.transforms import (
    Compose,
    Lambda,
    RandomCrop,
    RandomHorizontalFlip,
    Resize,
)
from transformers import VideoMAEFeatureExtractor, VideoMAEForVideoClassification
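
# Load the fine-tuned VideoMAE checkpoint and its feature extractor once at startup,
# then derive the input resolution and clip length the model expects from its config.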
MODEL_CKPT = "Aryanikale23/Signlanguage"
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
MODEL = VideoMAEForVideoClassification.from_pretrained(MODEL_CKPT).to(DEVICE)
PROCESSOR = VideoMAEFeatureExtractor.from_pretrained(MODEL_CKPT)
RESIZE_TO = PROCESSOR.size["shortest_edge"]
NUM_FRAMES_TO_SAMPLE = MODEL.config.num_frames
IMAGE_STATS = {"image_mean": [0.485, 0.456, 0.406], "image_std": [0.229, 0.224, 0.225]}
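
# Inference-time transforms: uniformly subsample the clip to the number of frames the
# model expects, scale pixel values to [0, 1], normalize with ImageNet statistics, and
# resize to the model's input resolution.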
VAL_TRANSFORMS = Compose(
    [
        UniformTemporalSubsample(NUM_FRAMES_TO_SAMPLE),
        Lambda(lambda x: x / 255.0),
        Normalize(IMAGE_STATS["image_mean"], IMAGE_STATS["image_std"]),
        Resize((RESIZE_TO, RESIZE_TO)),
    ]
)
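
# Human-readable class names, taken from the model config's label2id mapping.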
LABELS = list(MODEL.config.label2id.keys())

def parse_video(video_file):
    """Decode the input video into a list of RGB frames.

    Reference: https://pyimagesearch.com/2018/11/12/yolo-object-detection-with-opencv/
    """
    vs = cv2.VideoCapture(video_file)

    # Try to determine the total number of frames in the video file.
    try:
        prop = (
            cv2.cv.CV_CAP_PROP_FRAME_COUNT
            if imutils.is_cv2()
            else cv2.CAP_PROP_FRAME_COUNT
        )
        total = int(vs.get(prop))
        print("[INFO] {} total frames in video".format(total))
    # An error occurred while trying to determine the total
    # number of frames in the video file.
    except Exception:
        print("[INFO] could not determine # of frames in video")
        print("[INFO] no approx. completion time can be provided")
        total = -1

    frames = []
    # Loop over frames from the video file stream.
    while True:
        # Read the next frame from the file.
        (grabbed, frame) = vs.read()
        if frame is not None:
            # OpenCV decodes frames as BGR; convert to RGB for the model.
            frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
            frames.append(frame)
        # If the frame was not grabbed, we have reached the end of the stream.
        if not grabbed:
            break

    vs.release()
    return frames

def preprocess_video(frames: list):
    """Apply the inference-time preprocessing transforms to a list of video frames."""
    # Each frame in `frames` has shape (height, width, num_channels). Stacked together,
    # `frames` has shape (num_frames, height, width, num_channels). After converting the
    # list to a torch tensor, we permute it to (num_channels, num_frames, height, width)
    # so it is compatible with the preprocessing transforms. After applying the transform
    # chain, we permute back to (num_frames, num_channels, height, width) to match what
    # the model expects, and finally add a batch dimension.
    video_tensor = torch.tensor(np.array(frames).astype(frames[0].dtype))
    video_tensor = video_tensor.permute(3, 0, 1, 2)  # (num_channels, num_frames, height, width)
    video_tensor_pp = VAL_TRANSFORMS(video_tensor)
    video_tensor_pp = video_tensor_pp.permute(1, 0, 2, 3)  # (num_frames, num_channels, height, width)
    video_tensor_pp = video_tensor_pp.unsqueeze(0)  # add a batch dimension
    return video_tensor_pp.to(DEVICE)

def infer(video_file):
    """Run the video classification model on an uploaded video and return per-label scores."""
    frames = parse_video(video_file)
    video_tensor = preprocess_video(frames)
    inputs = {"pixel_values": video_tensor}

    # Forward pass (no gradients needed at inference time).
    with torch.no_grad():
        outputs = MODEL(**inputs)
    logits = outputs.logits
    softmax_scores = torch.nn.functional.softmax(logits, dim=-1).squeeze(0)
    confidences = {LABELS[i]: float(softmax_scores[i]) for i in range(len(LABELS))}
    return confidences
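
# Wire the inference function into a Gradio Interface: a video upload in, the top
# predicted labels with their confidence scores out.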
gr.Interface(
    fn=infer,
    inputs=gr.Video(),
    outputs=gr.Label(num_top_classes=3),
    title="VideoMAE fine-tuned on a subset of UCF-101",
    description=(
        "Gradio demo for VideoMAE for video classification. To use it, simply upload your video."
        " Read more at the links below."
    ),
    article=(
        "<div style='text-align: center;'><a href='https://huggingface.co/docs/transformers/model_doc/videomae' target='_blank'>VideoMAE</a>"
        " <center><a href='https://huggingface.co/sayakpaul/videomae-base-finetuned-kinetics-finetuned-ucf101-subset' target='_blank'>Fine-tuned Model</a></center></div>"
    ),
    allow_flagging="never",
).launch()
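
# Note: `infer` can also be called directly for a quick local check, e.g.
# `infer("/path/to/clip.mp4")` (hypothetical path), which returns a {label: score} dict.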