bingyikang commited on
Commit
72960a7
1 Parent(s): 192e5fb

fix proj moel

Browse files
bubogpt/models/mm_gpt4.py CHANGED
@@ -276,7 +276,7 @@ class MMGPT4(BaseModel):
276
  with_bind_head = cfg.get("with_bind_head", False)
277
  freeze_llm = cfg.get("freeze_llm", True)
278
  use_blip_vision = cfg.get("use_blip_vision", False)
279
- proj_model = cfg.get("proj_model", "checkpoints/prerained_minigpt4_7b.pth")
280
 
281
  model = cls(
282
  joiner_cfg=joiner_cfg,
 
276
  with_bind_head = cfg.get("with_bind_head", False)
277
  freeze_llm = cfg.get("freeze_llm", True)
278
  use_blip_vision = cfg.get("use_blip_vision", False)
279
+ proj_model = cfg.get("proj_model", "")
280
 
281
  model = cls(
282
  joiner_cfg=joiner_cfg,
imagebind/models/image_bind.py CHANGED
@@ -656,8 +656,9 @@ def replace_joiner_vision(joiner, q_former_model, proj_model):
656
 
657
  joiner.modality_qformers[ModalityType.VISION].load_Qformer(q_former_model)
658
 
659
- state_dict = torch.load(proj_model, map_location="cpu")["model"]
660
- params = type(state_dict)()
661
- params["fc.weight"] = state_dict["llama_proj.weight"]
662
- params["fc.bias"] = state_dict["llama_proj.bias"]
663
- joiner.modality_post_projectors[ModalityType.VISION].load_state_dict(params, strict=False)
 
 
656
 
657
  joiner.modality_qformers[ModalityType.VISION].load_Qformer(q_former_model)
658
 
659
+ if proj_model:
660
+ state_dict = torch.load(proj_model, map_location="cpu")["model"]
661
+ params = type(state_dict)()
662
+ params["fc.weight"] = state_dict["llama_proj.weight"]
663
+ params["fc.bias"] = state_dict["llama_proj.bias"]
664
+ joiner.modality_post_projectors[ModalityType.VISION].load_state_dict(params, strict=False)