Spaces:
Runtime error
Runtime error
Hui
commited on
Commit
•
0b1f893
1
Parent(s):
91b59cb
fix model
Browse files
app.py
CHANGED
@@ -23,11 +23,12 @@ transform = transforms.Compose([transforms.ToTensor(),
|
|
23 |
|
24 |
# model imports
|
25 |
def load_pretrained_params(model, model_state_path: str):
|
26 |
-
|
|
|
27 |
model_dict = model.state_dict()
|
28 |
# 1. filter out unnecessary keys
|
29 |
-
if list(pretrained_dict.keys())[0].startswith("
|
30 |
-
pretrained_dict = {k[
|
31 |
else:
|
32 |
pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}
|
33 |
# 2. overwrite entries in the existing state dict
|
|
|
23 |
|
24 |
# model imports
|
25 |
def load_pretrained_params(model, model_state_path: str):
|
26 |
+
checkpoint = torch.load(model_state_path, map_location="cpu")
|
27 |
+
pretrained_dict = checkpoint["state_dict"]
|
28 |
model_dict = model.state_dict()
|
29 |
# 1. filter out unnecessary keys
|
30 |
+
if list(pretrained_dict.keys())[0].startswith("model."):
|
31 |
+
pretrained_dict = {k[6:]: v for k, v in pretrained_dict.items() if k[6:] in model_dict}
|
32 |
else:
|
33 |
pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}
|
34 |
# 2. overwrite entries in the existing state dict
|