shin-mashita
commited on
Commit
•
7a783a2
1
Parent(s):
9754b6d
Minor edits
Browse files
app.py
CHANGED
@@ -27,6 +27,7 @@ def preprocess(vidpath):
|
|
27 |
|
28 |
frames.append(img)
|
29 |
|
|
|
30 |
frames = torch.Tensor(np.asarray(frames, dtype=np.float32))
|
31 |
|
32 |
transform = transforms.Compose([videotransforms.CenterCrop(224)])
|
@@ -44,28 +45,34 @@ def classify(video,dataset='WLASL100'):
|
|
44 |
input = preprocess(video)
|
45 |
|
46 |
model = InceptionI3d()
|
47 |
-
model.cpu
|
48 |
-
model.load_state_dict(torch.load('weights/rgb_imagenet.pt'))
|
49 |
model.replace_logits(to_load[dataset]['logits'])
|
50 |
-
model.load_state_dict(torch.load(to_load[dataset]['path']))
|
|
|
|
|
|
|
|
|
51 |
model.eval()
|
52 |
|
53 |
with torch.no_grad():
|
54 |
per_frame_logits = model(input)
|
55 |
|
|
|
|
|
|
|
56 |
predictions = rearrange(per_frame_logits,'1 j k -> j k')
|
57 |
predictions = torch.mean(predictions, dim = 1)
|
58 |
|
59 |
top = torch.argmax(predictions).item()
|
60 |
_, index = torch.topk(predictions,10)
|
61 |
-
index = index.numpy()
|
62 |
|
63 |
with open('wlasl_class_list.txt') as f:
|
64 |
idx2label = dict()
|
65 |
for line in f:
|
66 |
idx2label[int(line.split()[0])]=line.split()[1]
|
67 |
|
68 |
-
predictions = torch.nn.functional.softmax(predictions, dim=0).numpy()
|
69 |
|
70 |
return {idx2label[i]:float(predictions[i]) for i in index}
|
71 |
|
|
|
27 |
|
28 |
frames.append(img)
|
29 |
|
30 |
+
# frames = torch.cuda.FloatTensor(np.asarray(frames, dtype=np.float32)) if torch.cuda.is_available() else torch.Tensor(np.asarray(frames, dtype=np.float32))
|
31 |
frames = torch.Tensor(np.asarray(frames, dtype=np.float32))
|
32 |
|
33 |
transform = transforms.Compose([videotransforms.CenterCrop(224)])
|
|
|
45 |
input = preprocess(video)
|
46 |
|
47 |
model = InceptionI3d()
|
48 |
+
model.load_state_dict(torch.load('weights/rgb_imagenet.pt',map_location=torch.device('cpu')))
|
|
|
49 |
model.replace_logits(to_load[dataset]['logits'])
|
50 |
+
model.load_state_dict(torch.load(to_load[dataset]['path'],map_location=torch.device('cpu')))
|
51 |
+
|
52 |
+
# device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')
|
53 |
+
# model.to(device)
|
54 |
+
model.cpu()
|
55 |
model.eval()
|
56 |
|
57 |
with torch.no_grad():
|
58 |
per_frame_logits = model(input)
|
59 |
|
60 |
+
per_frame_logits.cpu()
|
61 |
+
model.cpu()
|
62 |
+
|
63 |
predictions = rearrange(per_frame_logits,'1 j k -> j k')
|
64 |
predictions = torch.mean(predictions, dim = 1)
|
65 |
|
66 |
top = torch.argmax(predictions).item()
|
67 |
_, index = torch.topk(predictions,10)
|
68 |
+
index = index.cpu().numpy()
|
69 |
|
70 |
with open('wlasl_class_list.txt') as f:
|
71 |
idx2label = dict()
|
72 |
for line in f:
|
73 |
idx2label[int(line.split()[0])]=line.split()[1]
|
74 |
|
75 |
+
predictions = torch.nn.functional.softmax(predictions, dim=0).cpu().numpy()
|
76 |
|
77 |
return {idx2label[i]:float(predictions[i]) for i in index}
|
78 |
|