shin-mashita commited on
Commit
7a783a2
1 Parent(s): 9754b6d

Minor edits

Browse files
Files changed (1) hide show
  1. app.py +12 -5
app.py CHANGED
@@ -27,6 +27,7 @@ def preprocess(vidpath):
27
 
28
  frames.append(img)
29
 
 
30
  frames = torch.Tensor(np.asarray(frames, dtype=np.float32))
31
 
32
  transform = transforms.Compose([videotransforms.CenterCrop(224)])
@@ -44,28 +45,34 @@ def classify(video,dataset='WLASL100'):
44
  input = preprocess(video)
45
 
46
  model = InceptionI3d()
47
- model.cpu()
48
- model.load_state_dict(torch.load('weights/rgb_imagenet.pt'))
49
  model.replace_logits(to_load[dataset]['logits'])
50
- model.load_state_dict(torch.load(to_load[dataset]['path']))
 
 
 
 
51
  model.eval()
52
 
53
  with torch.no_grad():
54
  per_frame_logits = model(input)
55
 
 
 
 
56
  predictions = rearrange(per_frame_logits,'1 j k -> j k')
57
  predictions = torch.mean(predictions, dim = 1)
58
 
59
  top = torch.argmax(predictions).item()
60
  _, index = torch.topk(predictions,10)
61
- index = index.numpy()
62
 
63
  with open('wlasl_class_list.txt') as f:
64
  idx2label = dict()
65
  for line in f:
66
  idx2label[int(line.split()[0])]=line.split()[1]
67
 
68
- predictions = torch.nn.functional.softmax(predictions, dim=0).numpy()
69
 
70
  return {idx2label[i]:float(predictions[i]) for i in index}
71
 
 
27
 
28
  frames.append(img)
29
 
30
+ # frames = torch.cuda.FloatTensor(np.asarray(frames, dtype=np.float32)) if torch.cuda.is_available() else torch.Tensor(np.asarray(frames, dtype=np.float32))
31
  frames = torch.Tensor(np.asarray(frames, dtype=np.float32))
32
 
33
  transform = transforms.Compose([videotransforms.CenterCrop(224)])
 
45
  input = preprocess(video)
46
 
47
  model = InceptionI3d()
48
+ model.load_state_dict(torch.load('weights/rgb_imagenet.pt',map_location=torch.device('cpu')))
 
49
  model.replace_logits(to_load[dataset]['logits'])
50
+ model.load_state_dict(torch.load(to_load[dataset]['path'],map_location=torch.device('cpu')))
51
+
52
+ # device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')
53
+ # model.to(device)
54
+ model.cpu()
55
  model.eval()
56
 
57
  with torch.no_grad():
58
  per_frame_logits = model(input)
59
 
60
+ per_frame_logits.cpu()
61
+ model.cpu()
62
+
63
  predictions = rearrange(per_frame_logits,'1 j k -> j k')
64
  predictions = torch.mean(predictions, dim = 1)
65
 
66
  top = torch.argmax(predictions).item()
67
  _, index = torch.topk(predictions,10)
68
+ index = index.cpu().numpy()
69
 
70
  with open('wlasl_class_list.txt') as f:
71
  idx2label = dict()
72
  for line in f:
73
  idx2label[int(line.split()[0])]=line.split()[1]
74
 
75
+ predictions = torch.nn.functional.softmax(predictions, dim=0).cpu().numpy()
76
 
77
  return {idx2label[i]:float(predictions[i]) for i in index}
78