Update codebase/inference/inference.py
Browse files
codebase/inference/inference.py
CHANGED
@@ -24,12 +24,12 @@ def build_model(model_name, ckpt_path, device):
24       if model_name == "ViT-B-32":
25           model, _, _ = open_clip.create_model_and_transforms("ViT-B/32", pretrained="openai")
26           checkpoint = torch.load(ckpt_path, map_location="cpu")
27 -         msg = model.load_state_dict(checkpoint
27 +         msg = model.load_state_dict(checkpoint)
28
29       elif model_name == "ViT-H-14":
30           model, _, _ = open_clip.create_model_and_transforms("ViT-H/14", pretrained="laion2b_s32b_b79k")
31           checkpoint = torch.load(ckpt_path, map_location="cpu")
32 -         msg = model.load_state_dict(checkpoint
32 +         msg = model.load_state_dict(checkpoint)
33
34       print(msg)
35       model = model.to(device)