ktrk115 commited on
Commit
d76abce
1 Parent(s): 5c75040

Upload push_to_hub.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. push_to_hub.py +76 -0
push_to_hub.py ADDED
@@ -0,0 +1,76 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+
3
+ import huggingface_hub
4
+ import torch
5
+ from vqmodel.configuration_vqmodel import VQModelConfig
6
+ from vqmodel.image_processing_vqmodel import VQModelImageProcessor
7
+ from vqmodel.modeling_vqmodel import VQModel
8
+
9
+ VQModelConfig.register_for_auto_class()
10
+ VQModel.register_for_auto_class()
11
+ VQModelImageProcessor.register_for_auto_class()
12
+
13
+
14
+ def main():
15
+ args = parse_args()
16
+ config = VQModelConfig(yaml_path=args.yaml_path)
17
+ model = VQModel(config)
18
+ load_model_weights(model, args.ckpt_path)
19
+
20
+ # Define image processor
21
+ ddconfig = model.vq_cfg.model.params.ddconfig
22
+ image_processor = VQModelImageProcessor(
23
+ size=ddconfig.resolution,
24
+ convert_rgb=ddconfig.in_channels == 3,
25
+ )
26
+
27
+ # Edit config
28
+ model.config.repo_id = args.repo_id
29
+ model.config.yaml_path = "config.yaml"
30
+
31
+ # Push to hub
32
+ model.push_to_hub(args.repo_id, private=True)
33
+ image_processor.push_to_hub(args.repo_id, private=True)
34
+ api = huggingface_hub.HfApi()
35
+ api.upload_file(
36
+ path_or_fileobj=args.yaml_path,
37
+ path_in_repo="config.yaml",
38
+ repo_id=args.repo_id,
39
+ )
40
+ api.upload_file(
41
+ path_or_fileobj=__file__,
42
+ path_in_repo="push_to_hub.py",
43
+ repo_id=args.repo_id,
44
+ )
45
+ api.upload_file(
46
+ path_or_fileobj="requirements.txt",
47
+ path_in_repo="requirements.txt",
48
+ repo_id=args.repo_id,
49
+ )
50
+
51
+
52
+ def parse_args():
53
+ parser = argparse.ArgumentParser()
54
+ parser.add_argument("--repo_id", type=str, required=True, help="Repository ID")
55
+ parser.add_argument(
56
+ "--yaml_path", type=str, required=True, help="Path to YAML file"
57
+ )
58
+ parser.add_argument(
59
+ "--ckpt_path", type=str, required=True, help="Path to checkpoint file"
60
+ )
61
+ return parser.parse_args()
62
+
63
+
64
+ def load_model_weights(model, ckpt_path):
65
+ # Load checkpoint
66
+ ckpt = torch.load(ckpt_path, map_location="cpu")["state_dict"]
67
+
68
+ # Remove loss related states
69
+ for key in list(ckpt.keys()):
70
+ if key.startswith("loss."):
71
+ del ckpt[key]
72
+ model.model.load_state_dict(ckpt)
73
+
74
+
75
+ if __name__ == "__main__":
76
+ main()