hwjiang commited on
Commit
90d11eb
1 Parent(s): c140eb6

Update tsr/system.py

Browse files
Files changed (1) hide show
  1. tsr/system.py +10 -1
tsr/system.py CHANGED
@@ -67,7 +67,16 @@ class TSR(BaseModule):
67
  OmegaConf.resolve(cfg)
68
  model = cls(cfg)
69
  ckpt = torch.load(weight_path, map_location="cpu")
70
- model.load_state_dict(ckpt)
 
 
 
 
 
 
 
 
 
71
  return model
72
 
73
  def configure(self):
 
67
  OmegaConf.resolve(cfg)
68
  model = cls(cfg)
69
  ckpt = torch.load(weight_path, map_location="cpu")
70
+
71
+ if "module" in list(ckpt["state_dict"].keys())[0]:
72
+ state_dict = {key.replace('module.',''): item for key, item in checkpoint["state_dict"].items()}
73
+ else:
74
+ state_dict = ckpt["state_dict"]
75
+ missing_states = set(model.state_dict().keys()) - set(state_dict.keys())
76
+ if len(missing_states) > 0:
77
+ warnings.warn("Missing keys ! : {}".format(missing_states))
78
+
79
+ model.load_state_dict(state_dict)
80
  return model
81
 
82
  def configure(self):