ariG23498 HF staff commited on
Commit
9199a4f
β€’
1 Parent(s): 220f660

device allocation

Browse files
Files changed (1) hide show
  1. models/maskclip/maskclip.py +2 -1
models/maskclip/maskclip.py CHANGED
@@ -121,7 +121,8 @@ class MaskClipHead(nn.Module):
121
  return aug_embeddings.squeeze(1)
122
 
123
  def load_visual_projs(self):
124
- loaded = torch.load(self.visual_projs_path, map_location='cuda')
 
125
  attrs = ['proj']
126
  for attr in attrs:
127
  current_attr = getattr(self, attr)
 
121
  return aug_embeddings.squeeze(1)
122
 
123
  def load_visual_projs(self):
124
+ device = "cuda" if torch.cuda.is_available() else "cpu"
125
+ loaded = torch.load(self.visual_projs_path, map_location=device)
126
  attrs = ['proj']
127
  for attr in attrs:
128
  current_attr = getattr(self, attr)