Hui commited on
Commit
0b1f893
1 Parent(s): 91b59cb
Files changed (1) hide show
  1. app.py +4 -3
app.py CHANGED
@@ -23,11 +23,12 @@ transform = transforms.Compose([transforms.ToTensor(),
23
 
24
  # model imports
25
  def load_pretrained_params(model, model_state_path: str):
26
- pretrained_dict = torch.load(model_state_path, map_location="cpu")
 
27
  model_dict = model.state_dict()
28
  # 1. filter out unnecessary keys
29
- if list(pretrained_dict.keys())[0].startswith("module."):
30
- pretrained_dict = {k[7:]: v for k, v in pretrained_dict.items() if k[7:] in model_dict}
31
  else:
32
  pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}
33
  # 2. overwrite entries in the existing state dict
 
23
 
24
  # model imports
25
  def load_pretrained_params(model, model_state_path: str):
26
+ checkpoint = torch.load(model_state_path, map_location="cpu")
27
+ pretrained_dict = checkpoint["state_dict"]
28
  model_dict = model.state_dict()
29
  # 1. filter out unnecessary keys
30
+ if list(pretrained_dict.keys())[0].startswith("model."):
31
+ pretrained_dict = {k[6:]: v for k, v in pretrained_dict.items() if k[6:] in model_dict}
32
  else:
33
  pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}
34
  # 2. overwrite entries in the existing state dict