Zilun commited on
Commit
6906236
1 Parent(s): 457dd08

Update codebase/inference/inference.py

Browse files
Files changed (1) hide show
  1. codebase/inference/inference.py +2 -2
codebase/inference/inference.py CHANGED
@@ -24,12 +24,12 @@ def build_model(model_name, ckpt_path, device):
24
  if model_name == "ViT-B-32":
25
  model, _, _ = open_clip.create_model_and_transforms("ViT-B/32", pretrained="openai")
26
  checkpoint = torch.load(ckpt_path, map_location="cpu")
27
- msg = model.load_state_dict(checkpoint, strict=False)
28
 
29
  elif model_name == "ViT-H-14":
30
  model, _, _ = open_clip.create_model_and_transforms("ViT-H/14", pretrained="laion2b_s32b_b79k")
31
  checkpoint = torch.load(ckpt_path, map_location="cpu")
32
- msg = model.load_state_dict(checkpoint, strict=False)
33
 
34
  print(msg)
35
  model = model.to(device)
 
24
  if model_name == "ViT-B-32":
25
  model, _, _ = open_clip.create_model_and_transforms("ViT-B/32", pretrained="openai")
26
  checkpoint = torch.load(ckpt_path, map_location="cpu")
27
+ msg = model.load_state_dict(checkpoint)
28
 
29
  elif model_name == "ViT-H-14":
30
  model, _, _ = open_clip.create_model_and_transforms("ViT-H/14", pretrained="laion2b_s32b_b79k")
31
  checkpoint = torch.load(ckpt_path, map_location="cpu")
32
+ msg = model.load_state_dict(checkpoint)
33
 
34
  print(msg)
35
  model = model.to(device)