p1atdev commited on
Commit
0fbabef
1 Parent(s): 898b3fa

Upload extract_controlnet.py

Browse files
Files changed (1) hide show
  1. extract_controlnet.py +39 -0
extract_controlnet.py ADDED
@@ -0,0 +1,39 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import torch
3
+ from safetensors.torch import load_file, save_file
4
+
5
+ if __name__ == "__main__":
6
+ parser = argparse.ArgumentParser()
7
+ parser.add_argument("--src", default=None, type=str, required=True, help="Path to the model to convert.")
8
+ parser.add_argument("--dst", default=None, type=str, required=True, help="Path to the output model.")
9
+ parser.add_argument("--fp16", action="store_true", help="Whether to convert the model to fp16.")
10
+ args = parser.parse_args()
11
+
12
+ assert args.src is not None, "Must provide a model path!"
13
+ assert args.dst is not None, "Must provide a checkpoint path!"
14
+
15
+ if args.src.endswith(".safetensors"):
16
+ state_dict = load_file(args.src, map_location="cpu")
17
+ else:
18
+ state_dict = torch.load(args.src, map_location="cpu")
19
+
20
+ try:
21
+ state_dict = state_dict['state_dict']["state_dict"]
22
+ except:
23
+ try:
24
+ state_dict = state_dict['state_dict']
25
+ except:
26
+ pass
27
+
28
+ if args.fp16:
29
+ if any([k.startswith("control_model.") for k, v in state_dict.items()]):
30
+ state_dict = {k.replace("control_model.", ""): v.half() for k, v in state_dict.items() if k.startswith("control_model.")}
31
+ else:
32
+ if any([k.startswith("control_model.") for k, v in state_dict.items()]):
33
+ state_dict = {k.replace("control_model.", ""): v for k, v in state_dict.items() if k.startswith("control_model.")}
34
+
35
+
36
+ if args.dst.endswith(".safetensors"):
37
+ save_file(state_dict, args.dst)
38
+ else:
39
+ torch.save({"state_dict": state_dict}, args.dst)