Fix model to load weights
fromage/models.py (CHANGED: +8 -8)
@@ -634,21 +634,21 @@ def load_fromage(embeddings_dir: str, model_args_path: str, model_ckpt_path: str
   ret_token_idx = tokenizer('[RET]', add_special_tokens=False).input_ids
   assert len(ret_token_idx) == 1, ret_token_idx
   model_kwargs['retrieval_token_idx'] = ret_token_idx[0]
-  model_kwargs['opt_version'] = 'facebook/opt-125m'
-  model_kwargs['visual_encoder'] = 'openai/clip-vit-base-patch32'
+  # model_kwargs['opt_version'] = 'facebook/opt-125m'
+  # model_kwargs['visual_encoder'] = 'openai/clip-vit-base-patch32'
   args = namedtuple('args', model_kwargs)(**model_kwargs)

   # Initialize model for inference.
   model = Fromage(tokenizer, args, path_array=path_array, emb_matrix=emb_matrix)
   model = model.eval()
-
-
+  model = model.bfloat16()
+  model = model.cuda()

   # Load pretrained linear mappings and [RET] embeddings.
-
-
-
-
+  checkpoint = torch.load(model_ckpt_path)
+  model.load_state_dict(checkpoint['state_dict'], strict=False)
+  with torch.no_grad():
+    model.model.input_embeddings.weight[model.model.retrieval_token_idx, :].copy_(checkpoint['state_dict']['ret_input_embeddings.weight'].cpu().detach())

   logit_scale = model.model.logit_scale.exp()
   emb_matrix = torch.tensor(emb_matrix, dtype=logit_scale.dtype).to(logit_scale.device)
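For context on the new loading logic: the added lines follow a standard PyTorch pattern, loading a partial checkpoint with strict=False so that only the saved parameters (the linear mappings and [RET] embedding) overwrite the freshly initialized model, then copying the pretrained [RET] row into the input-embedding table under torch.no_grad(). The sketch below is a minimal reproduction of that pattern on a toy module; TinyModel, its layer names, and the fabricated checkpoint are illustrative stand-ins, not code from fromage/models.py.

# Minimal sketch of the checkpoint-loading pattern used in the diff above (toy module, fabricated checkpoint).
import torch
import torch.nn as nn

class TinyModel(nn.Module):
    def __init__(self, vocab_size=10, dim=4):
        super().__init__()
        self.input_embeddings = nn.Embedding(vocab_size, dim)  # stands in for the LM token embeddings
        self.linear = nn.Linear(dim, dim)                      # stands in for the learned mapping layers

model = TinyModel().eval()
ret_token_idx = 9  # illustrative index of the [RET] token in the toy vocab

# Fabricated partial checkpoint; in the real code this comes from torch.load(model_ckpt_path).
checkpoint = {'state_dict': {
    'linear.weight': torch.randn(4, 4),
    'linear.bias': torch.zeros(4),
    'ret_input_embeddings.weight': torch.randn(1, 4),
}}

# strict=False tolerates keys that are missing from or extra in the checkpoint,
# so only the parameters that were actually saved get overwritten.
missing, unexpected = model.load_state_dict(checkpoint['state_dict'], strict=False)

# Copy the pretrained [RET] row into the embedding table without tracking gradients.
with torch.no_grad():
    model.input_embeddings.weight[ret_token_idx, :].copy_(
        checkpoint['state_dict']['ret_input_embeddings.weight'][0])

Note that calling bfloat16() and cuda() before load_state_dict, as the diff does, is fine: load_state_dict copies checkpoint tensors into the existing parameters, casting them to each parameter's dtype and device.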