SynLayers commited on
Commit
3a8ebe3
·
verified ·
1 Parent(s): 6cbd779

Upload tools/download_ckpt.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. tools/download_ckpt.py +114 -0
tools/download_ckpt.py ADDED
@@ -0,0 +1,114 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import logging
3
+ from pathlib import Path
4
+
5
+ from huggingface_hub import snapshot_download
6
+
7
+
8
+ logger = logging.getLogger(__name__)
9
+ DEFAULT_SYNLAYERS_REPO = "thuteam/" + "C" + "LD"
10
+
11
+
12
+ def parse_args():
13
+ parser = argparse.ArgumentParser(description="Download SynLayers pretrained checkpoints.")
14
+ parser.add_argument(
15
+ "--project-root",
16
+ default=str(Path(__file__).resolve().parents[1]),
17
+ help="SynLayers project root directory.",
18
+ )
19
+ parser.add_argument(
20
+ "--flux-dir",
21
+ default=None,
22
+ help="Output directory for FLUX.1-dev weights.",
23
+ )
24
+ parser.add_argument(
25
+ "--adapter-dir",
26
+ default=None,
27
+ help="Output directory for FLUX.1-dev ControlNet adapter weights.",
28
+ )
29
+ parser.add_argument(
30
+ "--download-synlayers",
31
+ dest="download_synlayers",
32
+ action="store_true",
33
+ help="Download SynLayers ckpt folder into project root.",
34
+ )
35
+ parser.add_argument(
36
+ "--synlayers-repo",
37
+ dest="synlayers_repo",
38
+ default=DEFAULT_SYNLAYERS_REPO,
39
+ help="Hugging Face repo for SynLayers-compatible checkpoints.",
40
+ )
41
+ parser.add_argument(
42
+ "--skip-flux",
43
+ action="store_true",
44
+ help="Skip downloading FLUX.1-dev weights.",
45
+ )
46
+ parser.add_argument(
47
+ "--skip-adapter",
48
+ action="store_true",
49
+ help="Skip downloading FLUX.1-dev ControlNet adapter weights.",
50
+ )
51
+ return parser.parse_args()
52
+
53
+
54
+ def download_flux(target_dir):
55
+ logger.info("Downloading FLUX.1-dev to %s", target_dir)
56
+ snapshot_download(
57
+ repo_id="black-forest-labs/FLUX.1-dev",
58
+ local_dir=str(target_dir),
59
+ local_dir_use_symlinks=False,
60
+ )
61
+
62
+
63
+ def download_adapter(target_dir):
64
+ logger.info("Downloading FLUX.1-dev-Controlnet-Inpainting-Alpha to %s", target_dir)
65
+ snapshot_download(
66
+ repo_id="alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Alpha",
67
+ local_dir=str(target_dir),
68
+ local_dir_use_symlinks=False,
69
+ )
70
+
71
+
72
+ def download_synlayers_ckpt(project_root, repo_id):
73
+ ckpt_dir = project_root / "ckpt"
74
+ logger.info("Downloading SynLayers ckpt files into %s", ckpt_dir)
75
+ snapshot_download(
76
+ repo_id=repo_id,
77
+ local_dir=str(ckpt_dir),
78
+ allow_patterns=[
79
+ "decouple_LoRA/**",
80
+ "pre_trained_LoRA/**",
81
+ "prism_ft_LoRA/**",
82
+ "trans_vae/**",
83
+ "README.md",
84
+ ],
85
+ local_dir_use_symlinks=False,
86
+ )
87
+
88
+
89
+ def main():
90
+ logging.basicConfig(level=logging.INFO, format="[%(levelname)s] %(message)s")
91
+ args = parse_args()
92
+ logger.info("Note: dataset JSONL sizes should be multiples of 8 for inference.")
93
+ project_root = Path(args.project_root).resolve()
94
+ default_ckpt_root = project_root / "ckpt"
95
+ flux_dir = Path(args.flux_dir) if args.flux_dir else default_ckpt_root / "FLUX.1-dev"
96
+ adapter_dir = (
97
+ Path(args.adapter_dir)
98
+ if args.adapter_dir
99
+ else default_ckpt_root / "FLUX.1-dev-Controlnet-Inpainting-Alpha"
100
+ )
101
+
102
+ if not args.skip_flux:
103
+ flux_dir.mkdir(parents=True, exist_ok=True)
104
+ download_flux(flux_dir)
105
+ if not args.skip_adapter:
106
+ adapter_dir.mkdir(parents=True, exist_ok=True)
107
+ download_adapter(adapter_dir)
108
+ if args.download_synlayers:
109
+ download_synlayers_ckpt(project_root, args.synlayers_repo)
110
+
111
+
112
+ if __name__ == "__main__":
113
+ main()
114
+