Fix device error when using cuda

#4
by majinyu - opened
Files changed (1) hide show
  1. models/tag2text.py +2 -3
models/tag2text.py CHANGED
@@ -152,8 +152,7 @@ class RAM(nn.Module):
152
  self.class_threshold[key] = value
153
 
154
  def load_tag_list(self, tag_list_file):
155
- with open(tag_list_file, 'r', encoding="utf-8") as f:
156
- # with open(tag_list_file, 'r') as f:
157
  tag_list = f.read().splitlines()
158
  tag_list = np.array(tag_list)
159
  return tag_list
@@ -362,7 +361,7 @@ class Tag2Text_Caption(nn.Module):
362
  logits = self.fc(tagging_embed[0])
363
 
364
  targets = torch.where(
365
- torch.sigmoid(logits) > self.class_threshold,
366
  torch.tensor(1.0).to(image.device),
367
  torch.zeros(self.num_class).to(image.device))
368
 
 
152
  self.class_threshold[key] = value
153
 
154
  def load_tag_list(self, tag_list_file):
155
+ with open(tag_list_file, 'r', encoding="utf8") as f:
 
156
  tag_list = f.read().splitlines()
157
  tag_list = np.array(tag_list)
158
  return tag_list
 
361
  logits = self.fc(tagging_embed[0])
362
 
363
  targets = torch.where(
364
+ torch.sigmoid(logits) > self.class_threshold.to(image.device),
365
  torch.tensor(1.0).to(image.device),
366
  torch.zeros(self.num_class).to(image.device))
367