hvaldez committed on
Commit
ee924e5
1 Parent(s): c9ca468

updating for gpu

Browse files
Files changed (2) hide show
  1. app.py +1 -0
  2. demo.py +9 -5
app.py CHANGED
@@ -53,6 +53,7 @@ def main():
53
  "configs/ego_mcq/svitt.yml",
54
  sample_videos,
55
  )
 
56
  def predict(text):
57
  idx = sample_text_dict[text]
58
  ft_action, gt_action = svitt.predict(idx, text)
 
53
  "configs/ego_mcq/svitt.yml",
54
  sample_videos,
55
  )
56
+
57
  def predict(text):
58
  idx = sample_text_dict[text]
59
  ft_action, gt_action = svitt.predict(idx, text)
demo.py CHANGED
@@ -24,9 +24,13 @@ class VideoModel(nn.Module):
24
  Parameters:
25
  config: config file
26
  """
27
- super(VideoModel, self).__init__()
28
  self.cfg = load_cfg(config)
29
  self.model = self.build_model()
 
 
 
 
30
  self.templates = ['{}']
31
  self.dataset = self.cfg['data']['dataset']
32
  self.eval()
@@ -74,7 +78,7 @@ class VideoModel(nn.Module):
74
  class VideoCLSModel(VideoModel):
75
  """ Video model for video classification tasks (Charades-Ego, EGTEA). """
76
  def __init__(self, config, sample_videos):
77
- super(VideoCLSModel, self).__init__(config)
78
  self.sample_videos = sample_videos
79
  self.video_transform = self.init_video_transform()
80
 
@@ -125,7 +129,7 @@ class VideoCLSModel(VideoModel):
125
  truncation=True,
126
  max_length=self.model_cfg.max_txt_l.video,
127
  return_tensors="pt",
128
- )
129
  _, class_embeddings = self.model.encode_text(embeddings)
130
  return class_embeddings
131
 
@@ -143,7 +147,7 @@ class VideoCLSModel(VideoModel):
143
  pooled_image_feat_all = []
144
  for i in range(clips.shape[0]):
145
 
146
- images = clips[i,:].unsqueeze(0)
147
  bsz = images.shape[0]
148
 
149
  _, pooled_image_feat, *outputs = self.model.encode_image(images)
@@ -161,5 +165,5 @@ class VideoCLSModel(VideoModel):
161
  @torch.no_grad()
162
  def predict(self, idx, text=None):
163
  output, target = self.forward(idx, text)
164
- return output.numpy(), target
165
 
 
24
  Parameters:
25
  config: config file
26
  """
27
+ super().__init__()
28
  self.cfg = load_cfg(config)
29
  self.model = self.build_model()
30
+ use_gpu = torch.cuda.is_available()
31
+ self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
32
+ if use_gpu:
33
+ self.model = self.model.to(self.device)
34
  self.templates = ['{}']
35
  self.dataset = self.cfg['data']['dataset']
36
  self.eval()
 
78
  class VideoCLSModel(VideoModel):
79
  """ Video model for video classification tasks (Charades-Ego, EGTEA). """
80
  def __init__(self, config, sample_videos):
81
+ super().__init__(config)
82
  self.sample_videos = sample_videos
83
  self.video_transform = self.init_video_transform()
84
 
 
129
  truncation=True,
130
  max_length=self.model_cfg.max_txt_l.video,
131
  return_tensors="pt",
132
+ ).to(self.device)
133
  _, class_embeddings = self.model.encode_text(embeddings)
134
  return class_embeddings
135
 
 
147
  pooled_image_feat_all = []
148
  for i in range(clips.shape[0]):
149
 
150
+ images = clips[i,:].unsqueeze(0).to(self.device)
151
  bsz = images.shape[0]
152
 
153
  _, pooled_image_feat, *outputs = self.model.encode_image(images)
 
165
  @torch.no_grad()
166
  def predict(self, idx, text=None):
167
  output, target = self.forward(idx, text)
168
+ return output.cpu().numpy(), target
169