Alexander Bagus commited on
Commit
bbb38e8
·
1 Parent(s): 2527cf0
Files changed (1) hide show
  1. app.py +2 -5
app.py CHANGED
@@ -37,11 +37,8 @@ transformer = ZImageControlTransformer2DModel.from_pretrained(
37
 
38
  if TRANSFORMER_LOCAL is not None:
39
  print(f"From checkpoint: {TRANSFORMER_LOCAL}")
40
- if TRANSFORMER_LOCAL.endswith("safetensors"):
41
- from safetensors.torch import load_file, safe_open
42
- state_dict = load_file(TRANSFORMER_LOCAL)
43
- else:
44
- state_dict = torch.load(TRANSFORMER_LOCAL, map_location="cpu")
45
  state_dict = state_dict["state_dict"] if "state_dict" in state_dict else state_dict
46
 
47
  m, u = transformer.load_state_dict(state_dict, strict=False)
 
37
 
38
  if TRANSFORMER_LOCAL is not None:
39
  print(f"From checkpoint: {TRANSFORMER_LOCAL}")
40
+ from safetensors.torch import load_file, safe_open
41
+ state_dict = load_file(TRANSFORMER_LOCAL)
 
 
 
42
  state_dict = state_dict["state_dict"] if "state_dict" in state_dict else state_dict
43
 
44
  m, u = transformer.load_state_dict(state_dict, strict=False)