yyk19 commited on
Commit
e7a5f93
1 Parent(s): bc1f1f4
Files changed (1) hide show
  1. scripts/rendertext_tool.py +7 -6
scripts/rendertext_tool.py CHANGED
@@ -14,6 +14,12 @@ from torchvision.transforms import ToTensor
14
  from contextlib import nullcontext
15
 
16
  def load_model_from_config(cfg, ckpt, verbose=False, not_use_ckpt=False):
 
 
 
 
 
 
17
  if ckpt.endswith("model_states.pt"):
18
  sd = torch.load(ckpt, map_location='cpu')["module"]
19
  else:
@@ -25,12 +31,7 @@ def load_model_from_config(cfg, ckpt, verbose=False, not_use_ckpt=False):
25
  nk = k[7:]
26
  sd[nk] = sd[k]
27
  del sd[k]
28
-
29
- if "model_ema.input_blocks10in_layers0weight" not in sd:
30
- print("missing model_ema.input_blocks10in_layers0weight. set use_ema as False")
31
- cfg.model.params.use_ema = False
32
- model = instantiate_from_config(cfg.model)
33
-
34
  if not not_use_ckpt:
35
  m, u = model.load_state_dict(sd, strict=False)
36
  if len(m) > 0 and verbose:
 
14
  from contextlib import nullcontext
15
 
16
  def load_model_from_config(cfg, ckpt, verbose=False, not_use_ckpt=False):
17
+
18
+ if "model_ema.input_blocks10in_layers0weight" not in sd:
19
+ print("missing model_ema.input_blocks10in_layers0weight. set use_ema as False")
20
+ cfg.model.params.use_ema = False
21
+ model = instantiate_from_config(cfg.model)
22
+
23
  if ckpt.endswith("model_states.pt"):
24
  sd = torch.load(ckpt, map_location='cpu')["module"]
25
  else:
 
31
  nk = k[7:]
32
  sd[nk] = sd[k]
33
  del sd[k]
34
+
 
 
 
 
 
35
  if not not_use_ckpt:
36
  m, u = model.load_state_dict(sd, strict=False)
37
  if len(m) > 0 and verbose: