nick_93 committed
Commit
8f4661f
1 Parent(s): ee23f4b
Files changed (2)
  1. app.py +1 -1
  2. depth/models_depth/model.py +5 -4
app.py CHANGED
@@ -78,7 +78,7 @@ def main():
     model = EVPDepth(args=args, caption_aggregation=True)
     cudnn.benchmark = True
     model.to(device)
-    model_weight = torch.load(args.ckpt_dir)['model']
+    model_weight = torch.load(args.ckpt_dir, map_location=device)['model']
     if 'module' in next(iter(model_weight.items()))[0]:
         model_weight = OrderedDict((k[7:], v) for k, v in model_weight.items())
     model.load_state_dict(model_weight, strict=False)
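Note: the only change here is adding map_location=device. Without it, torch.load tries to restore the checkpoint's CUDA tensors onto the GPU they were saved from, which fails on CPU-only hosts. A minimal, self-contained sketch of the same loading pattern; the checkpoint path, the nn.Linear stand-in model, and the fabricated 'module.'-prefixed state dict are made up for illustration:

from collections import OrderedDict

import torch
import torch.nn as nn

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Fake a checkpoint in the shape the app expects ({'model': state_dict}),
# with keys prefixed "module." as if it had been saved from nn.DataParallel.
state = OrderedDict(("module." + k, v) for k, v in nn.Linear(4, 2).state_dict().items())
torch.save({"model": state}, "demo_ckpt.pth")

# map_location remaps any CUDA tensors in the file onto the available device,
# so the same loading code runs on GPU and CPU-only machines.
model_weight = torch.load("demo_ckpt.pth", map_location=device)["model"]

# Strip the DataParallel "module." prefix so the keys match an unwrapped model.
if "module" in next(iter(model_weight.items()))[0]:
    model_weight = OrderedDict((k[7:], v) for k, v in model_weight.items())

model = nn.Linear(4, 2).to(device)  # stand-in for EVPDepth(args=args, caption_aggregation=True)
model.load_state_dict(model_weight, strict=False)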
depth/models_depth/model.py CHANGED
@@ -310,9 +310,10 @@ class EVPDepthEncoder(nn.Module):
 
         self.text_adapter = TextAdapterRefer(text_dim=text_dim)
         self.gamma = nn.Parameter(torch.ones(text_dim) * 1e-4)
-
+        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
         if caption_aggregation:
-            class_embeddings = torch.load(f'{dataset}_class_embeddings_my_captions.pth')
+            class_embeddings = torch.load(f'{dataset}_class_embeddings_my_captions.pth', map_location=device)
             #class_embeddings_list = [value['class_embeddings'] for key, value in class_embeddings.items()]
             #stacked_embeddings = torch.stack(class_embeddings_list, dim=0)
             #class_embeddings = torch.mean(stacked_embeddings, dim=0).unsqueeze(0)
@@ -320,7 +321,7 @@ class EVPDepthEncoder(nn.Module):
             if 'aggregated' in class_embeddings:
                 class_embeddings = class_embeddings['aggregated']
             else:
-                clip_model = FrozenCLIPEmbedder(max_length=40,pool=False).cuda()
+                clip_model = FrozenCLIPEmbedder(max_length=40,pool=False).to(device)
                 class_embeddings_new = [clip_model.encode(value['caption'][0]) for key, value in class_embeddings.items()]
                 class_embeddings_new = torch.mean(torch.stack(class_embeddings_new, dim=0), dim=0)
                 class_embeddings['aggregated'] = class_embeddings_new
@@ -328,7 +329,7 @@ class EVPDepthEncoder(nn.Module):
                 class_embeddings = class_embeddings['aggregated']
             self.register_buffer('class_embeddings', class_embeddings)
         else:
-            self.class_embeddings = torch.load(f'{dataset}_class_embeddings_my_captions.pth')
+            self.class_embeddings = torch.load(f'{dataset}_class_embeddings_my_captions.pth', map_location=device)
 
         self.clip_model = FrozenCLIPEmbedder(max_length=40,pool=False)
         for param in self.clip_model.parameters():
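Note: the encoder change applies the same idea. A device is derived from torch.cuda.is_available(), the cached caption embeddings are loaded with map_location=device, and the temporary CLIP encoder is placed with .to(device) instead of the CUDA-only .cuda(). A small sketch of this device-agnostic placement pattern; the nn.Embedding stand-in (in place of FrozenCLIPEmbedder), the file name, and the random embeddings are assumptions for illustration only:

import torch
import torch.nn as nn

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# .cuda() raises on CPU-only machines; .to(device) degrades gracefully.
encoder = nn.Embedding(10, 16).to(device)  # stand-in for FrozenCLIPEmbedder(max_length=40, pool=False)

# Cache some per-class embeddings the way the encoder caches its caption embeddings ...
torch.save({"aggregated": torch.randn(10, 16)}, "demo_class_embeddings.pth")

# ... and load them back onto whichever device is actually available.
class_embeddings = torch.load("demo_class_embeddings.pth", map_location=device)["aggregated"]
print(class_embeddings.device, next(encoder.parameters()).device)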