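"""Gradio demo: isolated sign language recognition with InceptionI3d.

A video is decoded with OpenCV, normalized and center-cropped to 224 px,
then fed to an I3D model fine-tuned on the WLASL100 or WLASL2000 subset
of WLASL; the app reports the top predicted glosses with softmax scores.
"""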
import torch
import cv2
import videotransforms
import numpy as np
import gradio as gr
from einops import rearrange
from torchvision import transforms
from pytorch_i3d import InceptionI3d
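# videotransforms and pytorch_i3d are local helper modules (from the
# WLASL / pytorch-i3d codebase) assumed to sit alongside this script.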
def preprocess(vidpath):
    """Decode a video, upscale so the short side is at least 226 px,
    rescale pixels to [-1, 1], center-crop to 224, and return a
    (1, c, t, h, w) float tensor."""
    cap = cv2.VideoCapture(vidpath)
    frames = []
    cap.set(cv2.CAP_PROP_POS_FRAMES, 0)
    num = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
    for _ in range(num):
        ret, img = cap.read()
        if not ret:  # stop early if the reported frame count overshoots
            break
        h, w, _ = img.shape  # OpenCV frames are (height, width, channels)
        if h < 226 or w < 226:
            d = 226. - min(h, w)
            sc = 1 + d / min(h, w)
            img = cv2.resize(img, dsize=(0, 0), fx=sc, fy=sc)
        img = (img / 255.) * 2 - 1  # rescale from [0, 255] to [-1, 1]
        frames.append(img)
    cap.release()
    # (on a GPU build, the tensor could be moved to CUDA here instead)
    frames = torch.Tensor(np.asarray(frames, dtype=np.float32))
    transform = transforms.Compose([videotransforms.CenterCrop(224)])
    frames = transform(frames)
    frames = rearrange(frames, 't h w c -> 1 c t h w')
    return frames
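# e.g. preprocess('videos/no.mp4') yields a float tensor of shape
# (1, 3, T, 224, 224), where T is the number of decoded frames.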
def classify(video, dataset='WLASL100'):
    """Run I3D on a video and return the top-10 glosses with their
    softmax probabilities."""
    to_load = {
        'WLASL100': {'logits': 100, 'path': 'weights/asl100/FINAL_nslt_100_iters=896_top1=65.89_top5=84.11_top10=89.92.pt'},
        'WLASL2000': {'logits': 2000, 'path': 'weights/asl2000/FINAL_nslt_2000_iters=5104_top1=32.48_top5=57.31_top10=66.31.pt'},
    }

    frames = preprocess(video)  # renamed from `input`, which shadows a builtin

    # Build I3D with its ImageNet-pretrained backbone, then swap in the
    # classification head and weights fine-tuned on the chosen WLASL subset.
    model = InceptionI3d()
    model.load_state_dict(torch.load('weights/rgb_imagenet.pt', map_location=torch.device('cpu')))
    model.replace_logits(to_load[dataset]['logits'])
    model.load_state_dict(torch.load(to_load[dataset]['path'], map_location=torch.device('cpu')))
    # device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')
    # model.to(device)
    model.cpu()
    model.eval()

    with torch.no_grad():
        per_frame_logits = model(frames)

    # Average the per-frame logits over time, then rank the classes.
    predictions = rearrange(per_frame_logits, '1 j k -> j k')
    predictions = torch.mean(predictions, dim=1)
    _, index = torch.topk(predictions, 10)
    index = index.cpu().numpy()

    # Map class indices to glosses; split once so multi-word glosses survive.
    idx2label = dict()
    with open('wlasl_class_list.txt') as f:
        for line in f:
            idx, label = line.strip().split(maxsplit=1)
            idx2label[int(idx)] = label

    predictions = torch.nn.functional.softmax(predictions, dim=0).cpu().numpy()
    return {idx2label[i]: float(predictions[i]) for i in index}
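# e.g. classify('videos/no.mp4', 'WLASL100') returns a dict mapping the ten
# highest-scoring glosses to their softmax probabilities.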
title = "I3D Sign Language Recognition"
description = "Classify isolated sign language videos with an InceptionI3d model fine-tuned on the WLASL100 or WLASL2000 subset of the WLASL dataset."
examples = [
    ['videos/no.mp4', 'WLASL100'],
    ['videos/all.mp4', 'WLASL100'],
    ['videos/blue.mp4', 'WLASL2000'],
    ['videos/white.mp4', 'WLASL2000'],
    ['videos/accident.mp4', 'WLASL2000'],
]
gr.Interface(fn=classify,
             inputs=[gr.inputs.Video(label="VIDEO"),
                     gr.inputs.Dropdown(choices=['WLASL100', 'WLASL2000'], default='WLASL100', label='DATASET USED')],
             outputs=[gr.outputs.Label(num_top_classes=5, label='Top 5 Predictions')],
             allow_flagging="never",
             title=title,
             description=description,
             examples=examples).launch()