jykoh committed on
Commit 206f734
Parent: 6abad74

Fix model to load weights

Files changed (1)
  1. fromage/models.py +8 -8
fromage/models.py CHANGED
@@ -634,21 +634,21 @@ def load_fromage(embeddings_dir: str, model_args_path: str, model_ckpt_path: str
   ret_token_idx = tokenizer('[RET]', add_special_tokens=False).input_ids
   assert len(ret_token_idx) == 1, ret_token_idx
   model_kwargs['retrieval_token_idx'] = ret_token_idx[0]
-  model_kwargs['opt_version'] = 'facebook/opt-125m'
-  model_kwargs['visual_encoder'] = 'openai/clip-vit-base-patch32'
+  # model_kwargs['opt_version'] = 'facebook/opt-125m'
+  # model_kwargs['visual_encoder'] = 'openai/clip-vit-base-patch32'
   args = namedtuple('args', model_kwargs)(**model_kwargs)
 
   # Initialize model for inference.
   model = Fromage(tokenizer, args, path_array=path_array, emb_matrix=emb_matrix)
   model = model.eval()
-  # model = model.bfloat16()
-  # model = model.cuda()
+  model = model.bfloat16()
+  model = model.cuda()
 
   # Load pretrained linear mappings and [RET] embeddings.
-  # checkpoint = torch.load(model_ckpt_path)
-  # model.load_state_dict(checkpoint['state_dict'], strict=False)
-  # with torch.no_grad():
-  #   model.model.input_embeddings.weight[model.model.retrieval_token_idx, :].copy_(checkpoint['state_dict']['ret_input_embeddings.weight'].cpu().detach())
+  checkpoint = torch.load(model_ckpt_path)
+  model.load_state_dict(checkpoint['state_dict'], strict=False)
+  with torch.no_grad():
+    model.model.input_embeddings.weight[model.model.retrieval_token_idx, :].copy_(checkpoint['state_dict']['ret_input_embeddings.weight'].cpu().detach())
 
   logit_scale = model.model.logit_scale.exp()
   emb_matrix = torch.tensor(emb_matrix, dtype=logit_scale.dtype).to(logit_scale.device)
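For reference, a minimal usage sketch of the patched loader, assuming the three parameters visible in the hunk header are the only ones required and using placeholder paths (the filenames below are hypothetical, not defined by this commit). Because the fix uncomments model.bfloat16() and model.cuda(), a CUDA-capable GPU is needed at load time:

  from fromage import models

  # All paths below are hypothetical placeholders, not files shipped with this commit.
  model = models.load_fromage(
      embeddings_dir='/path/to/embeddings',                 # precomputed retrieval embeddings
      model_args_path='/path/to/model_args.json',           # saved model_kwargs (hypothetical filename)
      model_ckpt_path='/path/to/pretrained_ckpt.pth.tar',   # checkpoint now actually loaded by this fix
  )
  # The returned model is in eval mode, cast to bfloat16 on GPU, with the [RET]
  # embedding row restored from the checkpoint's 'ret_input_embeddings.weight'.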