John6666 committed on
Commit
908a50c
1 Parent(s): 92866e7

Upload 9 files

Files changed (9)
  1. README.md +14 -12
  2. app.py +541 -0
  3. convert_url_to_diffusers_multi_gr.py +466 -0
  4. packages.txt +1 -0
  5. presets.py +147 -0
  6. requirements.txt +12 -0
  7. sdutils.py +170 -0
  8. stkey.py +122 -0
  9. utils.py +297 -0
README.md CHANGED
@@ -1,12 +1,14 @@
1
- ---
2
- title: Gradio Uitest1
3
- emoji: 🌍
4
- colorFrom: pink
5
- colorTo: yellow
6
- sdk: gradio
7
- sdk_version: 5.9.1
8
- app_file: app.py
9
- pinned: false
10
- ---
11
-
12
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
1
+ ---
2
+ title: Download safetensors and convert to HF🤗 Diffusers format (SDXL / SD 1.5 / FLUX.1 / SD 3.5) Alpha
3
+ emoji: 🎨➡️🧨
4
+ colorFrom: indigo
5
+ colorTo: purple
6
+ sdk: gradio
7
+ sdk_version: 5.9.1
8
+ app_file: app.py
9
+ pinned: false
10
+ license: mit
11
+ short_description: Convert SDXL/1.5/3.5/FLUX.1 safetensors to HF🤗 Diffusers
12
+ ---
13
+
14
+ Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
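
As a quick sanity check after conversion, the resulting repo should load directly with Diffusers. A minimal sketch, assuming an SDXL checkpoint was converted into a hypothetical repo `yourid/newrepo`:

```python
import torch
from diffusers import StableDiffusionXLPipeline

# "yourid/newrepo" is a placeholder for the repo this space creates
pipe = StableDiffusionXLPipeline.from_pretrained(
    "yourid/newrepo", torch_dtype=torch.float16
).to("cuda")
image = pipe("a cat in a spacesuit", num_inference_steps=28).images[0]
image.save("test.png")
```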
app.py ADDED
@@ -0,0 +1,541 @@
1
+ import gradio as gr
2
+ from convert_url_to_diffusers_multi_gr import convert_url_to_diffusers_repo, get_dtypes, FLUX_BASE_REPOS, SD35_BASE_REPOS
3
+ from presets import (DEFAULT_DTYPE, schedulers, clips, t5s, sdxl_vaes, sdxl_loras, sdxl_preset_dict, sdxl_set_presets,
4
+ sd15_vaes, sd15_loras, sd15_preset_dict, sd15_set_presets, flux_vaes, flux_loras, flux_preset_dict, flux_set_presets,
5
+ sd35_vaes, sd35_loras, sd35_preset_dict, sd35_set_presets)
6
+ import os
7
+
8
+
9
+ HF_USER = os.getenv("HF_USER", "")
10
+ HF_REPO = os.getenv("HF_REPO", "")
11
+ HF_URL = os.getenv("HF_URL", "")
12
+ HF_OW = os.getenv("HF_OW", False)
13
+ HF_PR = os.getenv("HF_PR", False)
14
+
15
+ css = """
16
+ .title { font-size: 3em; align-items: center; text-align: center; }
17
+ .info { align-items: center; text-align: center; }
18
+ .block.result { margin: 1em 0; padding: 1em; box-shadow: 0 0 3px 3px #664422, 0 0 3px 2px #664422 inset; border-radius: 6px; background: #665544; }
19
+
20
+ .block.result p,
21
+ .block.result .prose,
22
+ .block-result * {
23
+ color: aquamarine;
24
+ font-family: "Atma", monospace !important;
25
+ font-size: 1.06em;
26
+ letter-spacing: 0.50px;
27
+ word-spacing: 6px;
28
+ }
29
+ /* I know that I use a lot of !important instruction..
30
+ Yeah, I know it's not really properly recommended to use !important instruction when writing CSS rules ..
31
+ But it's because that Gradio interface use also a custom css-theme in gr.Blocks() from: "NoCrypt/miku@>=1.2.2"
32
+ And all my !important instructions are for avoiding to get unexpected superceeded by that "NoCrypt/miku@>=1.2.2" css theme skin...
33
+ No more, no less!
34
+ */
35
+ .setting_tag::before {
36
+ content: "::";
37
+ font-family: system-ui, monospace !important;
38
+ font-size: 1.2em !important;
39
+ color: #736819 !important;
40
+ font-weight: bold;
41
+ display: block;
42
+ text-align: center !important;
43
+ margin: 0 auto !important;
44
+ }
45
+ .setting_tag {
46
+ margin-top: 1em !important;
47
+ font-size: 1.4em !important;
48
+ color: darkgoldenrod !important;
49
+ padding: 10px !important;
50
+ font-weight: normal !important;
51
+ border: 2px dotted gold !important;
52
+ background: white !important;
53
+ word-break: break-word !important;
54
+ display: block;
55
+ text-align: center !important;
56
+ margin: 0 auto !important;
57
+ text-shadow: 1px 1px 1px gold !important;
58
+ }
59
+ .setting_tag::after {
60
+ content: "::";
61
+ font-family: system-ui, monospace !important;
62
+ font-size: 1.2em !important;
63
+ color: #736819 !important;
64
+ font-weight: bold !important;
65
+ display: block !important;
66
+ text-align: center !important;
67
+ margin: 0 auto !important;
68
+ }
69
+ a.linkify_1,
70
+ a.linkify_1:focus,
71
+ a.linkify_1:visited {
72
+ color: red !important;
73
+ text-decoration: underline !important;
74
+ }
75
+ a.linkify_1:active,
76
+ a.linkify_1:hover {
77
+ color: darkred !important;
78
+ text-decoration: overline !important;
79
+ }
80
+ .details_info_block_expanded_override {
81
+ font-size: 0.95em !important;
82
+ color: grey !important;
83
+ padding: 10px !important;
84
+ font-weight: bold;
85
+ border: none !important;
86
+ box-shadow: 1px 1px 2px 3px whitesmoke, 3px 2px 1px 1px grey !important;
87
+ background: whitesmoke !important;
88
+ word-break: break-word !important;
89
+ }
90
+ .details_info_block {
91
+ font-size: 0.95em !important;
92
+ color: grey !important;
93
+ padding: 10px !important;
94
+ font-weight: bold;
95
+ border: none !important;
96
+ box-shadow: 1px 1px 2px 3px whitesmoke, 3px 2px 1px 1px grey !important;
97
+ background: whitesmoke !important;
98
+ word-break: break-word !important;
99
+ height: initial !important;
100
+ max-height: initial !important;
101
+ overflow: auto !important;
102
+ }
103
+ .details_info_block[is-expanded="False"]::before {
104
+ content: "(double-click to expand help...)";
105
+ user-select: none;
106
+ cursor: pointer;
107
+ font-family: "Atma" !important;
108
+ font-weight: 500 !important;
109
+ letter-spacing: 2px !important;
110
+ word-spacing: 5px !important;
111
+ color: darkmagenta !important;
112
+ display: inline-block !important;
113
+ font-size: 0.99em !important;
114
+ background: whitesmoke !important;
115
+ padding: 10px !important;
116
+ border: 2px solid black !important;
117
+ }
118
+ .details_info_block[is-expanded="False"] {
119
+ transition: 1.2s all;
120
+ height: 80px !important;
121
+ max-height: 80px !important;
122
+ overflow: hidden !important;
123
+ }
124
+ .details_info_block[is-expanded="True"]::before {
125
+ content: "(double-click to reduce help...)";
126
+ user-select: none;
127
+ cursor: pointer;
128
+ font-family: "Atma" !important;
129
+ font-weight: 500 !important;
130
+ letter-spacing: 2px !important;
131
+ word-spacing: 5px !important;
132
+ color: white !important;
133
+ display: block !important;
134
+ font-size: 0.99em !important;
135
+ background: #1c9e5c !important;
136
+ padding: 10px !important;
137
+ border: 2px solid black !important;
138
+ border-radius: 0px !important;
139
+ margin-bottom: 1em !important;
140
+ }
141
+ .details_info_block[is-expanded="True"] {
142
+ transition: 1.2s all;
143
+ height: initial !important;
144
+ max-height: initial !important;
145
+ overflow: auto !important;
146
+ margin-bottom: 1em !important;
147
+ }
148
+ .em_warning {
149
+ font-family: "Atma", system-ui !important;
150
+ font-weight: 400 !important;
151
+ font-style: normal !important;
152
+ font-size: 1.3em !important;
153
+ word-spacing: 3.2px !important;
154
+ color: orangered !important;
155
+ }
156
+ .spanify_safetensors_base_model {
157
+ font-family: "Nunito", system-ui, monospace !important;
158
+ font-weight: 600;
159
+ color: green !important;
160
+ }
161
+ .spanify_safetensors_checkpoint_model {
162
+ font-family: "Nunito", system-ui, monospace !important;
163
+ font-weight: 600;
164
+ color: orange !important;
165
+ }
166
+ .spanify_vae_model {
167
+ font-family: "Nunito", system-ui, monospace !important;
168
+ font-weight: 600;
169
+ color: deeppink !important;
170
+ }
171
+ .spanify_lora_checkpoint_model {
172
+ font-family: "Nunito", system-ui, monospace !important;
173
+ font-weight: 600;
174
+ color: #8c627b !important;
175
+ }
176
+ .spanify_other_model {
177
+ font-family: "Nunito", system-ui, monospace !important;
178
+ font-weight: 600;
179
+ color: silver !important;
180
+ }
181
+ .setting_tag_as_mini {
182
+ font-family: "Nunito", serif !important;
183
+ font-optical-sizing: auto !important;
184
+ font-weight: 600 !important;
185
+ font-size: 1.2em !important;
186
+ color: darkgoldenrod !important;
187
+ padding: 4px !important;
188
+ border: 2px dashed gold !important;
189
+ background: white !important;
190
+ word-break: break-word !important;
191
+ display: inline-block !important;
192
+ text-shadow: 1px 1px 1px gold !important;
193
+ }
194
+ li.has_divider_1 {
195
+ list-style: none;
196
+ }
197
+ li.has_divider_1 > div.is_divider_1 {
198
+ display: inline-block;
199
+ width: 100%;
200
+ height: 0.4em;
201
+ background: burlywood;
202
+ }
203
+ .accordion-element,
204
+ .accordion-element button:not(.reset-button):not(.primary),
205
+ .accordion-element button:not(.reset-button):not(.primary) * {
206
+ background: navy !important;
207
+ color: white !important;
208
+ font-family: "Nunito", serif !important;
209
+ font-optical-sizing: auto !important;
210
+ font-weight: 600 !important;
211
+ font-size: 1.2em !important;
212
+ }
213
+ /* The issue with a Gradio gr.Tab() component
214
+ is that a CSS class passed in the component's
215
+ declaration is useless, because it does not
216
+ let us target the TRUE clickable tab-button
217
+ DOM element..
218
+ so this is a workaround which assumes the tab sits at this selector path (at least for Gradio 5) */
219
+ /* any NOT active Gradio tab : */
220
+ .tab-wrapper .tab-container[role="tablist"] button[role="tab"]:not(.selected) {
221
+ background: red !important;
222
+ color: white !important;
223
+ font-family: "Nunito", serif !important;
224
+ font-optical-sizing: auto !important;
225
+ font-weight: 600 !important;
226
+ font-size: 1.2em !important;
227
+ }
228
+ /* the CURRENT active Gradio tab : */
229
+ .tab-wrapper .tab-container[role="tablist"] button[role="tab"].selected {
230
+ background: green !important;
231
+ color: white !important;
232
+ font-family: "Nunito", serif !important;
233
+ font-optical-sizing: auto !important;
234
+ font-weight: 600 !important;
235
+ font-size: 1.6em !important;
236
+ }
237
+ /* here we only target Gradio buttons
238
+ that are reset-buttons
239
+ INSIDE an accordion-element class */
240
+ .accordion-element button.reset-button {
241
+ background: #803e3e !important;
242
+ color: white !important;
243
+ font-family: "Atma", serif !important;
244
+ font-weight: 700 !important;
245
+ font-style: normal !important;
246
+ }
247
+ /* here we only target Gradio buttons
248
+ that are primary
249
+ INSIDE an accordion-element class */
250
+ .accordion-element button.primary {
251
+ background: #5f925e !important;
252
+ font-size: 1.6em !important;
253
+ font-style: oblique !important;
254
+ font-weight: normal !important;
255
+ text-transform: uppercase !important;
256
+ letter-spacing: 1px !important;
257
+ border-bottom-right-radius: 100em !important;
258
+ border-bottom-left-radius: 100em !important;
259
+ font-family: "Atma", serif !important;
260
+ font-weight: 700 !important;
261
+ font-style: normal !important;
262
+ }
263
+ """
264
+
265
+
266
+ help_dict = {
267
+ "hf_username": """
268
+ <article class="setting_tag" use-webfont="wf-nunito-regular">
269
+ (hf_username)
270
+ </article>
271
+ <div class="details_info_block_expanded_override"
272
+ use-webfont="wf-atma-light"
273
+ >
274
+ <em>Your HuggingFace username, no more, no less</em>
275
+ </div>""",
276
+ "hf_write_token_access": """
277
+ <article class="setting_tag" use-webfont="wf-nunito-regular">
278
+ (hf_write_token_access)
279
+ </article>
280
+ <div class="details_info_block"
281
+ is-expanded="False"
282
+ ondblclick="makeExpandable(this);"
283
+ use-webfont="wf-atma-light"
284
+ >
285
+ <em>Your HuggingFace Token with WRITE access</em>
286
+ <br>
287
+ <br>
288
+ - Your Token with WRITE access can be created for free at <a class="linkify_1" target="_blank" href="https://huggingface.co/settings/tokens">https://huggingface.co/settings/tokens</a>.
289
+ <br>
290
+ <br>
291
+ <em class=\"em_warning\">
292
+ Please note: once your token is created, save its value somewhere you can retrieve it later,
293
+ <br>
294
+ because afterwards it will no longer be possible to view its value from the
295
+ <br>
296
+ tokens page of your HuggingFace account!
297
+ </em>
298
+ </div>""",
299
+ }
300
+
301
+ def help(key):
302
+ with gr.Accordion("Help", open=False) as help:
303
+ gr.HTML(value=help_dict.get(key, ""))
304
+ return help
305
+
306
+
307
+ with gr.Blocks(theme="theNeofr/Syne", fill_width=True, css=css, delete_cache=(60, 3600)) as demo:
308
+ gr.Markdown("# Download SDXL / SD 1.5 / SD 3.5 / FLUX.1 safetensors, convert them to HF🤗 Diffusers format, and create your repo", elem_classes="title")
309
+ gr.Markdown(f"""
310
+ ### ⚠️IMPORTANT NOTICE⚠️<br>
311
+ It's dangerous to expose your access token or key to others.
312
+ If you use this space, I recommend duplicating it to your own HF account in advance.
313
+ Keys and tokens can be set as **Secrets** (`HF_TOKEN`, `CIVITAI_API_KEY`) if the space is placed in your own account.
314
+ That saves you the trouble of typing them in.<br>
315
+ This barely works on the CPU space; larger files can be converted if the space is duplicated to the more powerful **Zero GPU** hardware.
316
+ In particular, converting FLUX.1 or SD 3.5 is almost impossible on the CPU space.
317
+ ### The steps are the following:
318
+ 1. Paste a write-access token from [hf.co/settings/tokens](https://huggingface.co/settings/tokens).
319
+ 1. Enter a model download URL from Hugging Face, Civitai, or another site.
320
+ 1. If you want to download a model from Civitai, paste a Civitai API Key.
321
+ 1. Enter your HF user ID, e.g. 'yourid'.
322
+ 1. Enter your new repo name, e.g. 'newrepo'. If left empty, it is auto-completed.
323
+ 1. Set the parameters. If unsure, just use the defaults.
324
+ 1. Click "Submit".
325
+ 1. Wait patiently until the output changes. It takes approximately 2 to 3 minutes (for SDXL models downloaded from HF).
326
+ """)
327
+ with gr.Column():
328
+ dl_url = gr.Textbox(label="URL to download", placeholder="https://huggingface.co/bluepen5805/blue_pencil-XL/blob/main/blue_pencil-XL-v7.0.0.safetensors",
329
+ value=HF_URL, max_lines=1)
330
+ with gr.Group():
331
+ with gr.Row():
332
+ with gr.Column():
333
+ hf_user = gr.Textbox(label="Your HF user ID", placeholder="username", value=HF_USER, max_lines=1)
334
+ help("hf_username")
335
+ with gr.Column():
336
+ hf_repo = gr.Textbox(label="New repo name", placeholder="reponame", info="If empty, auto-complete", value=HF_REPO, max_lines=1)
337
+ with gr.Row(equal_height=True):
338
+ with gr.Column():
339
+ hf_token = gr.Textbox(label="Your HF write token", placeholder="hf_...", value="", max_lines=1)
340
+ gr.Markdown("Your token is available at [hf.co/settings/tokens](https://huggingface.co/settings/tokens).", elem_classes="info")
341
+ help("hf_write_token_access")
342
+ with gr.Column():
343
+ civitai_key = gr.Textbox(label="Your Civitai API Key (Optional)", info="If you download model from Civitai...", placeholder="", value="", max_lines=1)
344
+ gr.Markdown("Your Civitai API key is available at [https://civitai.com/user/account](https://civitai.com/user/account).", elem_classes="info")
345
+ with gr.Row():
346
+ is_upload_sf = gr.Checkbox(label="Upload single safetensors file into new repo", value=False)
347
+ is_private = gr.Checkbox(label="Create private repo", value=True)
348
+ gated = gr.Radio(label="Create gated repo", info="Gated repo must be public", choices=["auto", "manual", "False"], value="False")
349
+ with gr.Row():
350
+ is_overwrite = gr.Checkbox(label="Overwrite repo", value=HF_OW)
351
+ is_pr = gr.Checkbox(label="Create PR", value=HF_PR)
352
+ with gr.Tab("SDXL"):
353
+ with gr.Group():
354
+ sdxl_presets = gr.Radio(label="Presets", choices=list(sdxl_preset_dict.keys()), value=list(sdxl_preset_dict.keys())[0])
355
+ sdxl_mtype = gr.Textbox(value="SDXL", visible=False)
356
+ sdxl_dtype = gr.Radio(label="Output data type", choices=get_dtypes(), value=DEFAULT_DTYPE)
357
+ with gr.Accordion("Advanced settings", open=False):
358
+ with gr.Row():
359
+ sdxl_vae = gr.Dropdown(label="VAE", choices=sdxl_vaes, value="", allow_custom_value=True)
360
+ sdxl_scheduler = gr.Dropdown(label="Scheduler (Sampler)", choices=schedulers, value="Euler a")
361
+ sdxl_clip = gr.Dropdown(label="CLIP", choices=clips, value="", allow_custom_value=True)
362
+ with gr.Column():
363
+ with gr.Row():
364
+ sdxl_lora1 = gr.Dropdown(label="LoRA1", choices=sdxl_loras, value="", allow_custom_value=True, min_width=320, scale=2)
365
+ sdxl_lora1s = gr.Slider(minimum=-2, maximum=2, step=0.01, value=1.00, label="LoRA1 weight scale")
366
+ with gr.Row():
367
+ sdxl_lora2 = gr.Dropdown(label="LoRA2", choices=sdxl_loras, value="", allow_custom_value=True, min_width=320, scale=2)
368
+ sdxl_lora2s = gr.Slider(minimum=-2, maximum=2, step=0.01, value=1.00, label="LoRA2 weight scale")
369
+ with gr.Row():
370
+ sdxl_lora3 = gr.Dropdown(label="LoRA3", choices=sdxl_loras, value="", allow_custom_value=True, min_width=320, scale=2)
371
+ sdxl_lora3s = gr.Slider(minimum=-2, maximum=2, step=0.01, value=1.00, label="LoRA3 weight scale")
372
+ with gr.Row():
373
+ sdxl_lora4 = gr.Dropdown(label="LoRA4", choices=sdxl_loras, value="", allow_custom_value=True, min_width=320, scale=2)
374
+ sdxl_lora4s = gr.Slider(minimum=-2, maximum=2, step=0.01, value=1.00, label="LoRA4 weight scale")
375
+ with gr.Row():
376
+ sdxl_lora5 = gr.Dropdown(label="LoRA5", choices=sdxl_loras, value="", allow_custom_value=True, min_width=320, scale=2)
377
+ sdxl_lora5s = gr.Slider(minimum=-2, maximum=2, step=0.01, value=1.00, label="LoRA5 weight scale")
378
+ sdxl_run_button = gr.Button(value="Submit", variant="primary")
379
+ with gr.Tab("SD 1.5"):
380
+ with gr.Group():
381
+ sd15_presets = gr.Radio(label="Presets", choices=list(sd15_preset_dict.keys()), value=list(sd15_preset_dict.keys())[0])
382
+ sd15_mtype = gr.Textbox(value="SD 1.5", visible=False)
383
+ sd15_dtype = gr.Radio(label="Output data type", choices=get_dtypes(), value=DEFAULT_DTYPE)
384
+ with gr.Row():
385
+ sd15_ema = gr.Checkbox(label="Extract EMA", value=True, visible=True)
386
+ sd15_isize = gr.Radio(label="Image size", choices=["768", "512"], value="768")
387
+ sd15_sc = gr.Checkbox(label="Safety checker", value=False)
388
+ with gr.Accordion("Advanced settings", open=False):
389
+ with gr.Row():
390
+ sd15_vae = gr.Dropdown(label="VAE", choices=sd15_vaes, value="", allow_custom_value=True)
391
+ sd15_scheduler = gr.Dropdown(label="Scheduler (Sampler)", choices=schedulers, value="Euler")
392
+ sd15_clip = gr.Dropdown(label="CLIP", choices=clips, value="", allow_custom_value=True)
393
+ with gr.Column():
394
+ with gr.Row():
395
+ sd15_lora1 = gr.Dropdown(label="LoRA1", choices=sd15_loras, value="", allow_custom_value=True, min_width=320, scale=2)
396
+ sd15_lora1s = gr.Slider(minimum=-2, maximum=2, step=0.01, value=1.00, label="LoRA1 weight scale")
397
+ with gr.Row():
398
+ sd15_lora2 = gr.Dropdown(label="LoRA2", choices=sd15_loras, value="", allow_custom_value=True, min_width=320, scale=2)
399
+ sd15_lora2s = gr.Slider(minimum=-2, maximum=2, step=0.01, value=1.00, label="LoRA2 weight scale")
400
+ with gr.Row():
401
+ sd15_lora3 = gr.Dropdown(label="LoRA3", choices=sd15_loras, value="", allow_custom_value=True, min_width=320, scale=2)
402
+ sd15_lora3s = gr.Slider(minimum=-2, maximum=2, step=0.01, value=1.00, label="LoRA3 weight scale")
403
+ with gr.Row():
404
+ sd15_lora4 = gr.Dropdown(label="LoRA4", choices=sd15_loras, value="", allow_custom_value=True, min_width=320, scale=2)
405
+ sd15_lora4s = gr.Slider(minimum=-2, maximum=2, step=0.01, value=1.00, label="LoRA4 weight scale")
406
+ with gr.Row():
407
+ sd15_lora5 = gr.Dropdown(label="LoRA5", choices=sd15_loras, value="", allow_custom_value=True, min_width=320, scale=2)
408
+ sd15_lora5s = gr.Slider(minimum=-2, maximum=2, step=0.01, value=1.00, label="LoRA5 weight scale")
409
+ sd15_run_button = gr.Button(value="Submit", variant="primary")
410
+ with gr.Tab("FLUX.1"):
411
+ with gr.Group():
412
+ flux_presets = gr.Radio(label="Presets", choices=list(flux_preset_dict.keys()), value=list(flux_preset_dict.keys())[0])
413
+ flux_mtype = gr.Textbox(value="FLUX", visible=False)
414
+ flux_dtype = gr.Radio(label="Output data type", choices=get_dtypes(), value="bf16")
415
+ flux_base_repo = gr.Dropdown(label="Base repo ID", choices=FLUX_BASE_REPOS, value=FLUX_BASE_REPOS[0], allow_custom_value=True, visible=True)
416
+ with gr.Accordion("Advanced settings", open=False):
417
+ with gr.Row():
418
+ flux_vae = gr.Dropdown(label="VAE", choices=flux_vaes, value="", allow_custom_value=True)
419
+ flux_scheduler = gr.Dropdown(label="Scheduler (Sampler)", choices=[""], value="", visible=False)
420
+ with gr.Row():
421
+ flux_clip = gr.Dropdown(label="CLIP", choices=clips, value="", allow_custom_value=True)
422
+ flux_t5 = gr.Dropdown(label="T5", choices=t5s, value="", allow_custom_value=True)
423
+ with gr.Column():
424
+ with gr.Row():
425
+ flux_lora1 = gr.Dropdown(label="LoRA1", choices=flux_loras, value="", allow_custom_value=True, min_width=320, scale=2)
426
+ flux_lora1s = gr.Slider(minimum=-2, maximum=2, step=0.01, value=1.00, label="LoRA1 weight scale")
427
+ with gr.Row():
428
+ flux_lora2 = gr.Dropdown(label="LoRA2", choices=flux_loras, value="", allow_custom_value=True, min_width=320, scale=2)
429
+ flux_lora2s = gr.Slider(minimum=-2, maximum=2, step=0.01, value=1.00, label="LoRA2 weight scale")
430
+ with gr.Row():
431
+ flux_lora3 = gr.Dropdown(label="LoRA3", choices=flux_loras, value="", allow_custom_value=True, min_width=320, scale=2)
432
+ flux_lora3s = gr.Slider(minimum=-2, maximum=2, step=0.01, value=1.00, label="LoRA3 weight scale")
433
+ with gr.Row():
434
+ flux_lora4 = gr.Dropdown(label="LoRA4", choices=flux_loras, value="", allow_custom_value=True, min_width=320, scale=2)
435
+ flux_lora4s = gr.Slider(minimum=-2, maximum=2, step=0.01, value=1.00, label="LoRA4 weight scale")
436
+ with gr.Row():
437
+ flux_lora5 = gr.Dropdown(label="LoRA5", choices=flux_loras, value="", allow_custom_value=True, min_width=320, scale=2)
438
+ flux_lora5s = gr.Slider(minimum=-2, maximum=2, step=0.01, value=1.00, label="LoRA5 weight scale")
439
+ flux_run_button = gr.Button(value="Submit", variant="primary")
440
+ with gr.Tab("SD 3.5"):
441
+ with gr.Group():
442
+ sd35_presets = gr.Radio(label="Presets", choices=list(sd35_preset_dict.keys()), value=list(sd35_preset_dict.keys())[0])
443
+ sd35_mtype = gr.Textbox(value="SD 3.5", visible=False)
444
+ sd35_dtype = gr.Radio(label="Output data type", choices=get_dtypes(), value="bf16")
445
+ sd35_base_repo = gr.Dropdown(label="Base repo ID", choices=SD35_BASE_REPOS, value=SD35_BASE_REPOS[0], allow_custom_value=True, visible=True)
446
+ with gr.Accordion("Advanced settings", open=False):
447
+ with gr.Row():
448
+ sd35_vae = gr.Dropdown(label="VAE", choices=sd35_vaes, value="", allow_custom_value=True)
449
+ sd35_scheduler = gr.Dropdown(label="Scheduler (Sampler)", choices=[""], value="", visible=False)
450
+ with gr.Row():
451
+ sd35_clip = gr.Dropdown(label="CLIP", choices=clips, value="", allow_custom_value=True)
452
+ sd35_t5 = gr.Dropdown(label="T5", choices=t5s, value="", allow_custom_value=True)
453
+ with gr.Column():
454
+ with gr.Row():
455
+ sd35_lora1 = gr.Dropdown(label="LoRA1", choices=sd35_loras, value="", allow_custom_value=True, min_width=320, scale=2)
456
+ sd35_lora1s = gr.Slider(minimum=-2, maximum=2, step=0.01, value=1.00, label="LoRA1 weight scale")
457
+ with gr.Row():
458
+ sd35_lora2 = gr.Dropdown(label="LoRA2", choices=sd35_loras, value="", allow_custom_value=True, min_width=320, scale=2)
459
+ sd35_lora2s = gr.Slider(minimum=-2, maximum=2, step=0.01, value=1.00, label="LoRA2 weight scale")
460
+ with gr.Row():
461
+ sd35_lora3 = gr.Dropdown(label="LoRA3", choices=sd35_loras, value="", allow_custom_value=True, min_width=320, scale=2)
462
+ sd35_lora3s = gr.Slider(minimum=-2, maximum=2, step=0.01, value=1.00, label="LoRA3 weight scale")
463
+ with gr.Row():
464
+ sd35_lora4 = gr.Dropdown(label="LoRA4", choices=sd35_loras, value="", allow_custom_value=True, min_width=320, scale=2)
465
+ sd35_lora4s = gr.Slider(minimum=-2, maximum=2, step=0.01, value=1.00, label="LoRA4 weight scale")
466
+ with gr.Row():
467
+ sd35_lora5 = gr.Dropdown(label="LoRA5", choices=sd35_loras, value="", allow_custom_value=True, min_width=320, scale=2)
468
+ sd35_lora5s = gr.Slider(minimum=-2, maximum=2, step=0.01, value=1.00, label="LoRA5 weight scale")
469
+ sd35_run_button = gr.Button(value="Submit", variant="primary")
470
+ adv_args = gr.Textbox(label="Advanced arguments", value="", visible=False)
471
+ with gr.Group():
472
+ repo_urls = gr.CheckboxGroup(visible=False, choices=[], value=[])
473
+ output_md = gr.Markdown(label="Output", value="<br><br>", elem_classes="result")
474
+ clear_button = gr.Button(value="Clear Output", variant="secondary")
475
+ gr.DuplicateButton(value="Duplicate Space")
476
+
477
+ gr.Markdown("This webui was redesigned with ❤ by [theNeofr](https://huggingface.co/theNeofr)")
478
+ gr.on(
479
+ triggers=[sdxl_run_button.click],
480
+ fn=convert_url_to_diffusers_repo,
481
+ inputs=[dl_url, hf_user, hf_repo, hf_token, civitai_key, is_private, gated, is_overwrite, is_pr, is_upload_sf, repo_urls,
482
+ sdxl_dtype, sdxl_vae, sdxl_clip, flux_t5, sdxl_scheduler, sd15_ema, sd15_isize, sd15_sc, flux_base_repo, sdxl_mtype,
483
+ sdxl_lora1, sdxl_lora1s, sdxl_lora2, sdxl_lora2s, sdxl_lora3, sdxl_lora3s, sdxl_lora4, sdxl_lora4s, sdxl_lora5, sdxl_lora5s, adv_args],
484
+ outputs=[repo_urls, output_md],
485
+ )
486
+ sdxl_presets.change(
487
+ fn=sdxl_set_presets,
488
+ inputs=[sdxl_presets],
489
+ outputs=[sdxl_dtype, sdxl_vae, sdxl_scheduler, sdxl_lora1, sdxl_lora1s, sdxl_lora2, sdxl_lora2s, sdxl_lora3, sdxl_lora3s,
490
+ sdxl_lora4, sdxl_lora4s, sdxl_lora5, sdxl_lora5s],
491
+ queue=False,
492
+ )
493
+ gr.on(
494
+ triggers=[sd15_run_button.click],
495
+ fn=convert_url_to_diffusers_repo,
496
+ inputs=[dl_url, hf_user, hf_repo, hf_token, civitai_key, is_private, gated, is_overwrite, is_pr, is_upload_sf, repo_urls,
497
+ sd15_dtype, sd15_vae, sd15_clip, flux_t5, sd15_scheduler, sd15_ema, sd15_isize, sd15_sc, flux_base_repo, sd15_mtype,
498
+ sd15_lora1, sd15_lora1s, sd15_lora2, sd15_lora2s, sd15_lora3, sd15_lora3s, sd15_lora4, sd15_lora4s, sd15_lora5, sd15_lora5s, adv_args],
499
+ outputs=[repo_urls, output_md],
500
+ )
501
+ sd15_presets.change(
502
+ fn=sd15_set_presets,
503
+ inputs=[sd15_presets],
504
+ outputs=[sd15_dtype, sd15_vae, sd15_scheduler, sd15_lora1, sd15_lora1s, sd15_lora2, sd15_lora2s, sd15_lora3, sd15_lora3s,
505
+ sd15_lora4, sd15_lora4s, sd15_lora5, sd15_lora5s, sd15_ema],
506
+ queue=False,
507
+ )
508
+ gr.on(
509
+ triggers=[flux_run_button.click],
510
+ fn=convert_url_to_diffusers_repo,
511
+ inputs=[dl_url, hf_user, hf_repo, hf_token, civitai_key, is_private, gated, is_overwrite, is_pr, is_upload_sf, repo_urls,
512
+ flux_dtype, flux_vae, flux_clip, flux_t5, flux_scheduler, sd15_ema, sd15_isize, sd15_sc, flux_base_repo, flux_mtype,
513
+ flux_lora1, flux_lora1s, flux_lora2, flux_lora2s, flux_lora3, flux_lora3s, flux_lora4, flux_lora4s, flux_lora5, flux_lora5s, adv_args],
514
+ outputs=[repo_urls, output_md],
515
+ )
516
+ flux_presets.change(
517
+ fn=flux_set_presets,
518
+ inputs=[flux_presets],
519
+ outputs=[flux_dtype, flux_vae, flux_scheduler, flux_lora1, flux_lora1s, flux_lora2, flux_lora2s, flux_lora3, flux_lora3s,
520
+ flux_lora4, flux_lora4s, flux_lora5, flux_lora5s, flux_base_repo],
521
+ queue=False,
522
+ )
523
+ gr.on(
524
+ triggers=[sd35_run_button.click],
525
+ fn=convert_url_to_diffusers_repo,
526
+ inputs=[dl_url, hf_user, hf_repo, hf_token, civitai_key, is_private, gated, is_overwrite, is_pr, is_upload_sf, repo_urls,
527
+ sd35_dtype, sd35_vae, sd35_clip, sd35_t5, sd35_scheduler, sd15_ema, sd15_isize, sd15_sc, sd35_base_repo, sd35_mtype,
528
+ sd35_lora1, sd35_lora1s, sd35_lora2, sd35_lora2s, sd35_lora3, sd35_lora3s, sd35_lora4, sd35_lora4s, sd35_lora5, sd35_lora5s, adv_args],
529
+ outputs=[repo_urls, output_md],
530
+ )
531
+ sd35_presets.change(
532
+ fn=sd35_set_presets,
533
+ inputs=[sd35_presets],
534
+ outputs=[sd35_dtype, sd35_vae, sd35_scheduler, sd35_lora1, sd35_lora1s, sd35_lora2, sd35_lora2s, sd35_lora3, sd35_lora3s,
535
+ sd35_lora4, sd35_lora4s, sd35_lora5, sd35_lora5s, sd35_base_repo],
536
+ queue=False,
537
+ )
538
+ clear_button.click(lambda: ([], "<br><br>"), None, [repo_urls, output_md], queue=False, show_api=False)
539
+
540
+ demo.queue()
541
+ demo.launch(ssr_mode=False)
convert_url_to_diffusers_multi_gr.py ADDED
@@ -0,0 +1,466 @@
1
+ import os
2
+ import spaces
3
+ import argparse
4
+ from pathlib import Path
6
+ import torch
7
+ from diffusers import (DiffusionPipeline, AutoencoderKL, FlowMatchEulerDiscreteScheduler, StableDiffusionXLPipeline, StableDiffusionPipeline,
8
+ FluxPipeline, FluxTransformer2DModel, SD3Transformer2DModel, StableDiffusion3Pipeline)
9
+ from transformers import CLIPTokenizer, CLIPTextModel, CLIPTextModelWithProjection, CLIPFeatureExtractor, AutoTokenizer, T5EncoderModel, BitsAndBytesConfig as TFBitsAndBytesConfig
10
+ from huggingface_hub import save_torch_state_dict, snapshot_download
11
+ from diffusers.loaders.single_file_utils import (convert_flux_transformer_checkpoint_to_diffusers, convert_sd3_transformer_checkpoint_to_diffusers,
12
+ convert_sd3_t5_checkpoint_to_diffusers)
13
+ from diffusers.pipelines.stable_diffusion import StableDiffusionSafetyChecker
14
+ import safetensors.torch
15
+ import gradio as gr
16
+ import shutil
17
+ import gc
18
+ import tempfile
19
+ # also requires aria2, gdown, peft, huggingface_hub, safetensors, transformers, accelerate, pytorch_lightning
20
+ from utils import (get_token, set_token, is_repo_exists, is_repo_name, get_download_file, upload_repo, gate_repo)
21
+ from sdutils import (SCHEDULER_CONFIG_MAP, get_scheduler_config, fuse_loras, DTYPE_DEFAULT, get_dtype, get_dtypes, get_model_type_from_key, get_process_dtype)
22
+
23
+
24
+ @spaces.GPU
25
+ def fake_gpu():
26
+ pass
27
+
28
+
29
+ try:
30
+ from diffusers import BitsAndBytesConfig
31
+ is_nf4 = True
32
+ except Exception:
33
+ is_nf4 = False
34
+
35
+
36
+ FLUX_BASE_REPOS = ["camenduru/FLUX.1-dev-diffusers", "black-forest-labs/FLUX.1-schnell", "John6666/flux1-dev-fp8-flux", "John6666/flux1-schnell-fp8-flux"]
37
+ FLUX_T5_URL = "https://huggingface.co/camenduru/FLUX.1-dev/blob/main/t5xxl_fp8_e4m3fn.safetensors"
38
+ SD35_BASE_REPOS = ["adamo1139/stable-diffusion-3.5-large-ungated", "adamo1139/stable-diffusion-3.5-large-turbo-ungated"]
39
+ SD35_T5_URL = "https://huggingface.co/adamo1139/stable-diffusion-3.5-large-turbo-ungated/blob/main/text_encoders/t5xxl_fp8_e4m3fn.safetensors"
40
+ TEMP_DIR = tempfile.mkdtemp()
41
+ IS_ZERO = os.environ.get("SPACES_ZERO_GPU") is not None
42
+ IS_CUDA = torch.cuda.is_available()
43
+
44
+
45
+ def safe_clean(path: str):
46
+ try:
47
+ if Path(path).exists():
48
+ if Path(path).is_dir(): shutil.rmtree(str(Path(path)))
49
+ else: Path(path).unlink()
50
+ print(f"Deleted: {path}")
51
+ else: print(f"File not found: {path}")
52
+ except Exception as e:
53
+ print(f"Failed to delete: {path} {e}")
54
+
55
+
56
+ def save_readme_md(dir, url):
57
+ orig_url = ""
58
+ orig_name = ""
59
+ if is_repo_name(url):
60
+ orig_name = url
61
+ orig_url = f"https://huggingface.co/{url}/"
62
+ elif "http" in url:
63
+ orig_name = url
64
+ orig_url = url
65
+ if orig_name and orig_url:
66
+ md = f"""---
67
+ license: other
68
+ language:
69
+ - en
70
+ library_name: diffusers
71
+ pipeline_tag: text-to-image
72
+ tags:
73
+ - text-to-image
74
+ ---
75
+ Converted from [{orig_name}]({orig_url}).
76
+ """
77
+ else:
78
+ md = f"""---
79
+ license: other
80
+ language:
81
+ - en
82
+ library_name: diffusers
83
+ pipeline_tag: text-to-image
84
+ tags:
85
+ - text-to-image
86
+ ---
87
+ """
88
+ path = str(Path(dir, "README.md"))
89
+ with open(path, mode='w', encoding="utf-8") as f:
90
+ f.write(md)
91
+
92
+
93
+ def save_module(model, name: str, dir: str, dtype: str="fp8", progress=gr.Progress(track_tqdm=True)): # doesn't work
94
+ if name in ["vae", "transformer", "unet"]: pattern = "diffusion_pytorch_model{suffix}.safetensors"
95
+ else: pattern = "model{suffix}.safetensors"
96
+ if name in ["transformer", "unet"]: size = "10GB"
97
+ else: size = "5GB"
98
+ path = str(Path(f"{dir.removesuffix('/')}/{name}"))
99
+ os.makedirs(path, exist_ok=True)
100
+ progress(0, desc=f"Saving {name} to {dir}...")
101
+ print(f"Saving {name} to {dir}...")
102
+ model.to("cpu")
103
+ sd = dict(model.state_dict())
104
+ new_sd = {}
105
+ for key in list(sd.keys()):
106
+ q = sd.pop(key)
107
+ if dtype == "fp8": new_sd[key] = q if q.dtype == torch.float8_e4m3fn else q.to(torch.float8_e4m3fn)
108
+ else: new_sd[key] = q
109
+ del sd
110
+ gc.collect()
111
+ save_torch_state_dict(state_dict=new_sd, save_directory=path, filename_pattern=pattern, max_shard_size=size)
112
+ del new_sd
113
+ gc.collect()
114
+
115
+
116
+ def save_module_sd(sd: dict, name: str, dir: str, dtype: str="fp8", progress=gr.Progress(track_tqdm=True)):
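+ # Pops each tensor out of the source dict as it is cast, so roughly one full
+ # copy of the weights stays resident, then writes sharded safetensors files.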
117
+ if name in ["vae", "transformer", "unet"]: pattern = "diffusion_pytorch_model{suffix}.safetensors"
118
+ else: pattern = "model{suffix}.safetensors"
119
+ if name in ["transformer", "unet"]: size = "10GB"
120
+ else: size = "5GB"
121
+ path = str(Path(f"{dir.removesuffix('/')}/{name}"))
122
+ os.makedirs(path, exist_ok=True)
123
+ progress(0, desc=f"Saving state_dict of {name} to {dir}...")
124
+ print(f"Saving state_dict of {name} to {dir}...")
125
+ new_sd = {}
126
+ for key in list(sd.keys()):
127
+ q = sd.pop(key).to("cpu")
128
+ if dtype == "fp8": new_sd[key] = q if q.dtype == torch.float8_e4m3fn else q.to(torch.float8_e4m3fn)
129
+ else: new_sd[key] = q
130
+ save_torch_state_dict(state_dict=new_sd, save_directory=path, filename_pattern=pattern, max_shard_size=size)
131
+ del new_sd
132
+ gc.collect()
133
+
134
+
135
+ def convert_flux_fp8_cpu(new_file: str, new_dir: str, dtype: str, base_repo: str, civitai_key: str, kwargs: dict, progress=gr.Progress(track_tqdm=True)):
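+ # CPU-only fp8 path: convert the single-file FLUX transformer state dict to the
+ # Diffusers layout, save it as fp8 shards, add a prequantized T5 text_encoder_2,
+ # then fill in the remaining pipeline components and configs from base_repo.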
136
+ temp_dir = TEMP_DIR
137
+ down_dir = str(Path(f"{TEMP_DIR}/down"))
138
+ os.makedirs(down_dir, exist_ok=True)
139
+ hf_token = get_token()
140
+ progress(0.25, desc=f"Loading {new_file}...")
141
+ orig_sd = safetensors.torch.load_file(new_file)
142
+ progress(0.3, desc=f"Converting {new_file}...")
143
+ conv_sd = convert_flux_transformer_checkpoint_to_diffusers(orig_sd)
144
+ del orig_sd
145
+ gc.collect()
146
+ progress(0.35, desc=f"Saving {new_file}...")
147
+ save_module_sd(conv_sd, "transformer", new_dir, dtype)
148
+ del conv_sd
149
+ gc.collect()
150
+ progress(0.5, desc=f"Loading text_encoder_2 from {FLUX_T5_URL}...")
151
+ t5_file = get_download_file(temp_dir, FLUX_T5_URL, civitai_key)
152
+ if not t5_file: raise Exception(f"Safetensors file not found: {FLUX_T5_URL}")
153
+ t5_sd = safetensors.torch.load_file(t5_file)
154
+ safe_clean(t5_file)
155
+ save_module_sd(t5_sd, "text_encoder_2", new_dir, dtype)
156
+ del t5_sd
157
+ gc.collect()
158
+ progress(0.6, desc=f"Loading other components from {base_repo}...")
159
+ pipe = FluxPipeline.from_pretrained(base_repo, transformer=None, text_encoder_2=None, use_safetensors=True, **kwargs,
160
+ torch_dtype=torch.bfloat16, token=hf_token)
161
+ pipe.save_pretrained(new_dir)
162
+ progress(0.75, desc=f"Loading nontensor files from {base_repo}...")
163
+ snapshot_download(repo_id=base_repo, local_dir=down_dir, token=hf_token, force_download=True,
164
+ ignore_patterns=["*.safetensors", "*.sft", ".*", "README*", "*.md", "*.index", "*.jpg", "*.jpeg", "*.png", "*.webp"])
165
+ shutil.copytree(down_dir, new_dir, ignore=shutil.ignore_patterns(".*", "README*", "*.md", "*.jpg", "*.jpeg", "*.png", "*.webp"), dirs_exist_ok=True)
166
+ safe_clean(down_dir)
167
+
168
+
169
+ def convert_sd35_fp8_cpu(new_file: str, new_dir: str, dtype: str, base_repo: str, civitai_key: str, kwargs: dict, progress=gr.Progress(track_tqdm=True)):
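+ # Same CPU-only fp8 flow as convert_flux_fp8_cpu, but for SD 3.5: the T5 is
+ # text_encoder_3 and its state dict also needs convert_sd3_t5_checkpoint_to_diffusers().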
170
+ temp_dir = TEMP_DIR
171
+ down_dir = str(Path(f"{TEMP_DIR}/down"))
172
+ os.makedirs(down_dir, exist_ok=True)
173
+ hf_token = get_token()
174
+ progress(0.25, desc=f"Loading {new_file}...")
175
+ orig_sd = safetensors.torch.load_file(new_file)
176
+ progress(0.3, desc=f"Converting {new_file}...")
177
+ conv_sd = convert_sd3_transformer_checkpoint_to_diffusers(orig_sd)
178
+ del orig_sd
179
+ gc.collect()
180
+ progress(0.35, desc=f"Saving {new_file}...")
181
+ save_module_sd(conv_sd, "transformer", new_dir, dtype)
182
+ del conv_sd
183
+ gc.collect()
184
+ progress(0.5, desc=f"Loading text_encoder_3 from {SD35_T5_URL}...")
185
+ t5_file = get_download_file(temp_dir, SD35_T5_URL, civitai_key)
186
+ if not t5_file: raise Exception(f"Safetensors file not found: {SD35_T5_URL}")
187
+ t5_sd = safetensors.torch.load_file(t5_file)
188
+ safe_clean(t5_file)
189
+ conv_t5_sd = convert_sd3_t5_checkpoint_to_diffusers(t5_sd)
190
+ del t5_sd
191
+ gc.collect()
192
+ save_module_sd(conv_t5_sd, "text_encoder_3", new_dir, dtype)
193
+ del conv_t5_sd
194
+ gc.collect()
195
+ progress(0.6, desc=f"Loading other components from {base_repo}...")
196
+ pipe = StableDiffusion3Pipeline.from_pretrained(base_repo, transformer=None, text_encoder_3=None, use_safetensors=True, **kwargs,
197
+ torch_dtype=torch.bfloat16, token=hf_token)
198
+ pipe.save_pretrained(new_dir)
199
+ progress(0.75, desc=f"Loading nontensor files from {base_repo}...")
200
+ snapshot_download(repo_id=base_repo, local_dir=down_dir, token=hf_token, force_download=True,
201
+ ignore_patterns=["*.safetensors", "*.sft", ".*", "README*", "*.md", "*.index", "*.jpg", "*.jpeg", "*.png", "*.webp"])
202
+ shutil.copytree(down_dir, new_dir, ignore=shutil.ignore_patterns(".*", "README*", "*.md", "*.jpg", "*.jpeg", "*.png", "*.webp"), dirs_exist_ok=True)
203
+ safe_clean(down_dir)
204
+
205
+
206
+ #@spaces.GPU(duration=60)
207
+ def load_and_save_pipeline(pipe, model_type: str, url: str, new_file: str, new_dir: str, dtype: str,
208
+ scheduler: str, ema: bool, image_size: str, is_safety_checker: bool, base_repo: str, civitai_key: str, lora_dict: dict,
209
+ my_vae, my_clip_tokenizer, my_clip_encoder, my_t5_tokenizer, my_t5_encoder,
210
+ kwargs: dict, dkwargs: dict, progress=gr.Progress(track_tqdm=True)):
211
+ try:
212
+ hf_token = get_token()
213
+ temp_dir = TEMP_DIR
214
+ qkwargs = {}
215
+ tfqkwargs = {}
216
+ if is_nf4:
217
+ nf4_config = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_quant_type="nf4",
218
+ bnb_4bit_use_double_quant=True, bnb_4bit_compute_dtype=torch.bfloat16)
219
+ nf4_config_tf = TFBitsAndBytesConfig(load_in_4bit=True, bnb_4bit_quant_type="nf4",
220
+ bnb_4bit_use_double_quant=True, bnb_4bit_compute_dtype=torch.bfloat16)
221
+ else:
222
+ nf4_config = None
223
+ nf4_config_tf = None
224
+ if dtype == "NF4" and nf4_config is not None and nf4_config_tf is not None:
225
+ qkwargs["quantization_config"] = nf4_config
226
+ tfqkwargs["quantization_config"] = nf4_config_tf
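+ # NF4 quantization is only applied when diffusers exposes BitsAndBytesConfig
+ # (see the guarded import above); the transformers variant (TFBitsAndBytesConfig)
+ # is intended for the text encoders.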
227
+
228
+ #print(f"model_type:{model_type}, dtype:{dtype}, scheduler:{scheduler}, ema:{ema}, base_repo:{base_repo}")
229
+ #print("lora_dict:", lora_dict, "kwargs:", kwargs, "dkwargs:", dkwargs)
230
+
231
+ #t5 = None
232
+
233
+ if model_type == "SDXL":
234
+ if is_repo_name(url): pipe = StableDiffusionXLPipeline.from_pretrained(url, use_safetensors=True, **kwargs, **dkwargs, token=hf_token)
235
+ else: pipe = StableDiffusionXLPipeline.from_single_file(new_file, use_safetensors=True, **kwargs, **dkwargs)
236
+ pipe = fuse_loras(pipe, lora_dict, temp_dir, civitai_key, dkwargs)
237
+ sconf = get_scheduler_config(scheduler)
238
+ pipe.scheduler = sconf[0].from_config(pipe.scheduler.config, **sconf[1])
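+ # get_scheduler_config() returns (scheduler_class, extra_kwargs); rebuilding the
+ # scheduler here bakes the chosen sampler into the saved pipeline's scheduler config.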
239
+ pipe.save_pretrained(new_dir)
240
+ elif model_type == "SD 1.5":
241
+ if is_safety_checker:
242
+ safety_checker = StableDiffusionSafetyChecker.from_pretrained("CompVis/stable-diffusion-safety-checker")
243
+ feature_extractor = CLIPFeatureExtractor.from_pretrained("openai/clip-vit-base-patch32")
244
+ kwargs["requires_safety_checker"] = True
245
+ kwargs["safety_checker"] = safety_checker
246
+ kwargs["feature_extractor"] = feature_extractor
247
+ else: kwargs["requires_safety_checker"] = False
248
+ if is_repo_name(url): pipe = StableDiffusionPipeline.from_pretrained(url, extract_ema=ema, use_safetensors=True, **kwargs, **dkwargs, token=hf_token)
249
+ else: pipe = StableDiffusionPipeline.from_single_file(new_file, extract_ema=ema, use_safetensors=True, **kwargs, **dkwargs)
250
+ pipe = fuse_loras(pipe, lora_dict, temp_dir, civitai_key, dkwargs)
251
+ sconf = get_scheduler_config(scheduler)
252
+ pipe.scheduler = sconf[0].from_config(pipe.scheduler.config, **sconf[1])
253
+ if image_size != "512": pipe.vae = AutoencoderKL.from_config(pipe.vae.config, sample_size=int(image_size))
254
+ pipe.save_pretrained(new_dir)
255
+ elif model_type == "FLUX":
256
+ if dtype != "fp8":
257
+ if is_repo_name(url):
258
+ transformer = FluxTransformer2DModel.from_pretrained(url, subfolder="transformer", config=base_repo, **dkwargs, **qkwargs)
259
+ #if my_t5_encoder is None:
260
+ # t5 = T5EncoderModel.from_pretrained(url, subfolder="text_encoder_2", config=base_repo, **dkwargs, **tfqkwargs)
261
+ # kwargs["text_encoder_2"] = t5
262
+ pipe = FluxPipeline.from_pretrained(url, transformer=transformer, use_safetensors=True, **kwargs, **dkwargs, token=hf_token)
263
+ else:
264
+ transformer = FluxTransformer2DModel.from_single_file(new_file, subfolder="transformer", config=base_repo, **dkwargs, **qkwargs)
265
+ #if my_t5_encoder is None:
266
+ # t5 = T5EncoderModel.from_pretrained(base_repo, subfolder="text_encoder_2", config=base_repo, **dkwargs, **tfqkwargs)
267
+ # kwargs["text_encoder_2"] = t5
268
+ pipe = FluxPipeline.from_pretrained(base_repo, transformer=transformer, use_safetensors=True, **kwargs, **dkwargs, token=hf_token)
269
+ pipe = fuse_loras(pipe, lora_dict, temp_dir, civitai_key, dkwargs)
270
+ pipe.save_pretrained(new_dir)
271
+ elif not is_repo_name(url): convert_flux_fp8_cpu(new_file, new_dir, dtype, base_repo, civitai_key, kwargs)
272
+ elif model_type == "SD 3.5":
273
+ if dtype != "fp8":
274
+ if is_repo_name(url):
275
+ transformer = SD3Transformer2DModel.from_pretrained(url, subfolder="transformer", config=base_repo, **dkwargs, **qkwargs)
276
+ #if my_t5_encoder is None:
277
+ # t5 = T5EncoderModel.from_pretrained(url, subfolder="text_encoder_3", config=base_repo, **dkwargs, **tfqkwargs)
278
+ # kwargs["text_encoder_3"] = t5
279
+ pipe = StableDiffusion3Pipeline.from_pretrained(url, transformer=transformer, use_safetensors=True, **kwargs, **dkwargs, token=hf_token)
280
+ else:
281
+ transformer = SD3Transformer2DModel.from_single_file(new_file, subfolder="transformer", config=base_repo, **dkwargs, **qkwargs)
282
+ #if my_t5_encoder is None:
283
+ # t5 = T5EncoderModel.from_pretrained(base_repo, subfolder="text_encoder_3", config=base_repo, **dkwargs, **tfqkwargs)
284
+ # kwargs["text_encoder_3"] = t5
285
+ pipe = StableDiffusion3Pipeline.from_pretrained(base_repo, transformer=transformer, use_safetensors=True, **kwargs, **dkwargs, token=hf_token)
286
+ pipe = fuse_loras(pipe, lora_dict, temp_dir, civitai_key, dkwargs)
287
+ pipe.save_pretrained(new_dir)
288
+ elif not is_repo_name(url): convert_sd35_fp8_cpu(new_file, new_dir, dtype, base_repo, civitai_key, kwargs)
289
+ else: # unknown model type
290
+ if is_repo_name(url): pipe = DiffusionPipeline.from_pretrained(url, use_safetensors=True, **kwargs, **dkwargs, token=hf_token)
291
+ else: pipe = DiffusionPipeline.from_single_file(new_file, use_safetensors=True, **kwargs, **dkwargs)
292
+ pipe = fuse_loras(pipe, lora_dict, temp_dir, civitai_key, dkwargs)
293
+ pipe.save_pretrained(new_dir)
294
+ except Exception as e:
295
+ print(f"Failed to load pipeline. {e}")
296
+ raise Exception("Failed to load pipeline.") from e
297
+ return pipe # returned outside finally so a re-raised exception is not swallowed
299
+
300
+
301
+ def convert_url_to_diffusers(url: str, civitai_key: str="", is_upload_sf: bool=False, dtype: str="fp16", vae: str="", clip: str="", t5: str="",
302
+ scheduler: str="Euler a", ema: bool=True, image_size: str="768", safety_checker: bool=False,
303
+ base_repo: str="", mtype: str="", lora_dict: dict={}, is_local: bool=True, progress=gr.Progress(track_tqdm=True)):
304
+ try:
305
+ hf_token = get_token()
306
+ progress(0, desc="Start converting...")
307
+ temp_dir = TEMP_DIR
308
+
309
+ if is_repo_name(url) and is_repo_exists(url):
310
+ new_file = url
311
+ model_type = mtype
312
+ else:
313
+ new_file = get_download_file(temp_dir, url, civitai_key)
314
+ if not new_file: raise Exception(f"Safetensors file not found: {url}")
315
+ model_type = get_model_type_from_key(new_file)
316
+ new_dir = Path(new_file).stem.replace(" ", "_").replace(",", "_").replace(".", "_")
317
+
318
+ kwargs = {}
319
+ dkwargs = {}
320
+ if dtype != DTYPE_DEFAULT: dkwargs["torch_dtype"] = get_process_dtype(dtype, model_type)
321
+ pipe = None
322
+
323
+ print(f"Model type: {model_type} / VAE: {vae} / CLIP: {clip} / T5: {t5} / Scheduler: {scheduler} / dtype: {dtype} / EMA: {ema} / Base repo: {base_repo} / LoRAs: {lora_dict}")
324
+
325
+ my_vae = None
326
+ if vae:
327
+ progress(0, desc=f"Loading VAE: {vae}...")
328
+ if is_repo_name(vae): my_vae = AutoencoderKL.from_pretrained(vae, **dkwargs, token=hf_token)
329
+ else:
330
+ new_vae_file = get_download_file(temp_dir, vae, civitai_key)
331
+ my_vae = AutoencoderKL.from_single_file(new_vae_file, **dkwargs) if new_vae_file else None
332
+ safe_clean(new_vae_file)
333
+ if my_vae: kwargs["vae"] = my_vae
334
+
335
+ my_clip_tokenizer = None
336
+ my_clip_encoder = None
337
+ if clip:
338
+ progress(0, desc=f"Loading CLIP: {clip}...")
339
+ if is_repo_name(clip):
340
+ my_clip_tokenizer = CLIPTokenizer.from_pretrained(clip, token=hf_token)
341
+ if model_type == "SD 3.5": my_clip_encoder = CLIPTextModelWithProjection.from_pretrained(clip, **dkwargs, token=hf_token)
342
+ else: my_clip_encoder = CLIPTextModel.from_pretrained(clip, **dkwargs, token=hf_token)
343
+ else:
344
+ new_clip_file = get_download_file(temp_dir, clip, civitai_key)
345
+ if model_type == "SD 3.5": my_clip_encoder = CLIPTextModelWithProjection.from_single_file(new_clip_file, **dkwargs) if new_clip_file else None
346
+ else: my_clip_encoder = CLIPTextModel.from_single_file(new_clip_file, **dkwargs) if new_clip_file else None
347
+ safe_clean(new_clip_file)
348
+ if model_type == "SD 3.5":
349
+ if my_clip_tokenizer:
350
+ kwargs["tokenizer"] = my_clip_tokenizer
351
+ kwargs["tokenizer_2"] = my_clip_tokenizer
352
+ if my_clip_encoder:
353
+ kwargs["text_encoder"] = my_clip_encoder
354
+ kwargs["text_encoder_2"] = my_clip_encoder
355
+ else:
356
+ if my_clip_tokenizer: kwargs["tokenizer"] = my_clip_tokenizer
357
+ if my_clip_encoder: kwargs["text_encoder"] = my_clip_encoder
358
+
359
+ my_t5_tokenizer = None
360
+ my_t5_encoder = None
361
+ if t5:
362
+ progress(0, desc=f"Loading T5: {t5}...")
363
+ if is_repo_name(t5):
364
+ my_t5_tokenizer = AutoTokenizer.from_pretrained(t5, token=hf_token)
365
+ my_t5_encoder = T5EncoderModel.from_pretrained(t5, **dkwargs, token=hf_token)
366
+ else:
367
+ new_t5_file = get_download_file(temp_dir, t5, civitai_key)
368
+ my_t5_encoder = T5EncoderModel.from_single_file(new_t5_file, **dkwargs) if new_t5_file else None
369
+ safe_clean(new_t5_file)
370
+ if model_type == "SD 3.5":
371
+ if my_t5_tokenizer: kwargs["tokenizer_3"] = my_t5_tokenizer
372
+ if my_t5_encoder: kwargs["text_encoder_3"] = my_t5_encoder
373
+ else:
374
+ if my_t5_tokenizer: kwargs["tokenizer_2"] = my_t5_tokenizer
375
+ if my_t5_encoder: kwargs["text_encoder_2"] = my_t5_encoder
376
+
377
+ pipe = load_and_save_pipeline(pipe, model_type, url, new_file, new_dir, dtype, scheduler, ema, image_size, safety_checker, base_repo, civitai_key, lora_dict,
378
+ my_vae, my_clip_tokenizer, my_clip_encoder, my_t5_tokenizer, my_t5_encoder, kwargs, dkwargs)
379
+
380
+ if Path(new_dir).exists(): save_readme_md(new_dir, url)
381
+
382
+ if not is_local:
383
+ if not is_repo_name(new_file) and is_upload_sf: shutil.move(str(Path(new_file).resolve()), str(Path(new_dir, Path(new_file).name).resolve()))
384
+ else: safe_clean(new_file)
385
+
386
+ progress(1, desc="Converted.")
387
+ return new_dir
388
+ except Exception as e:
389
+ print(f"Failed to convert. {e}")
390
+ raise Exception("Failed to convert.") from e
391
+ finally:
392
+ if "pipe" in locals(): del pipe # pipe may be unbound if conversion failed early
393
+ torch.cuda.empty_cache()
394
+ gc.collect()
395
+
396
+
397
+ def convert_url_to_diffusers_repo(dl_url: str, hf_user: str, hf_repo: str, hf_token: str, civitai_key="", is_private: bool=True,
398
+ gated: str="False", is_overwrite: bool=False, is_pr: bool=False,
399
+ is_upload_sf: bool=False, urls: list=[], dtype: str="fp16", vae: str="", clip: str="", t5: str="", scheduler: str="Euler a",
400
+ ema: bool=True, image_size: str="768", safety_checker: bool=False,
401
+ base_repo: str="", mtype: str="", lora1: str="", lora1s=1.0, lora2: str="", lora2s=1.0, lora3: str="", lora3s=1.0,
402
+ lora4: str="", lora4s=1.0, lora5: str="", lora5s=1.0, args: str="", progress=gr.Progress(track_tqdm=True)):
403
+ try:
404
+ is_local = False
405
+ if not civitai_key and os.environ.get("CIVITAI_API_KEY"): civitai_key = os.environ.get("CIVITAI_API_KEY") # default Civitai API key
406
+ if not hf_token and os.environ.get("HF_TOKEN"): hf_token = os.environ.get("HF_TOKEN") # default HF write token
407
+ if not hf_user: raise gr.Error(f"Invalid user name: {hf_user}")
408
+ if gated != "False" and is_private: raise gr.Error("Gated repo must be public")
409
+ set_token(hf_token)
410
+ lora_dict = {lora1: lora1s, lora2: lora2s, lora3: lora3s, lora4: lora4s, lora5: lora5s}
411
+ new_path = convert_url_to_diffusers(dl_url, civitai_key, is_upload_sf, dtype, vae, clip, t5, scheduler, ema, image_size, safety_checker, base_repo, mtype, lora_dict, is_local)
412
+ if not new_path: raise gr.Error("Conversion failed.")
413
+ new_repo_id = f"{hf_user}/{Path(new_path).stem}"
414
+ if hf_repo != "": new_repo_id = f"{hf_user}/{hf_repo}"
415
+ if not is_repo_name(new_repo_id): raise gr.Error(f"Invalid repo name: {new_repo_id}")
416
+ if not is_overwrite and is_repo_exists(new_repo_id) and not is_pr: raise gr.Error(f"Repo already exists: {new_repo_id}")
417
+ repo_url = upload_repo(new_repo_id, new_path, is_private, is_pr)
418
+ gate_repo(new_repo_id, gated)
419
+ safe_clean(new_path)
420
+ if not urls: urls = []
421
+ urls.append(repo_url)
422
+ md = "### Your new repo:\n"
423
+ for u in urls:
424
+ md += f"[{str(u).split('/')[-2]}/{str(u).split('/')[-1]}]({str(u)})<br>"
425
+ return gr.update(value=urls, choices=urls), gr.update(value=md)
426
+ except Exception as e:
427
+ print(f"Error occurred. {e}")
428
+ raise gr.Error(f"Error occurred. {e}")
429
+
430
+
431
+ if __name__ == "__main__":
432
+ parser = argparse.ArgumentParser()
433
+
434
+ parser.add_argument("--url", type=str, required=True, help="URL of the model to convert.")
435
+ parser.add_argument("--dtype", default="fp16", type=str, choices=get_dtypes(), help='Output data type. (Default: "fp16")')
436
+ parser.add_argument("--scheduler", default="Euler a", type=str, choices=list(SCHEDULER_CONFIG_MAP.keys()), required=False, help="Scheduler name to use.")
437
+ parser.add_argument("--vae", default="", type=str, required=False, help="URL or Repo ID of the VAE to use.")
438
+ parser.add_argument("--clip", default="", type=str, required=False, help="URL or Repo ID of the CLIP to use.")
439
+ parser.add_argument("--t5", default="", type=str, required=False, help="URL or Repo ID of the T5 to use.")
440
+ parser.add_argument("--base", default="", type=str, required=False, help="Repo ID of the base repo.")
441
+ parser.add_argument("--nonema", action="store_true", default=False, help="Don't extract EMA (for SD 1.5).")
442
+ parser.add_argument("--civitai_key", default="", type=str, required=False, help="Civitai API Key (If you want to download file from Civitai).")
443
+ parser.add_argument("--lora1", default="", type=str, required=False, help="URL of the LoRA to use.")
444
+ parser.add_argument("--lora1s", default=1.0, type=float, required=False, help="LoRA weight scale of --lora1.")
445
+ parser.add_argument("--lora2", default="", type=str, required=False, help="URL of the LoRA to use.")
446
+ parser.add_argument("--lora2s", default=1.0, type=float, required=False, help="LoRA weight scale of --lora2.")
447
+ parser.add_argument("--lora3", default="", type=str, required=False, help="URL of the LoRA to use.")
448
+ parser.add_argument("--lora3s", default=1.0, type=float, required=False, help="LoRA weight scale of --lora3.")
449
+ parser.add_argument("--lora4", default="", type=str, required=False, help="URL of the LoRA to use.")
450
+ parser.add_argument("--lora4s", default=1.0, type=float, required=False, help="LoRA weight scale of --lora4.")
451
+ parser.add_argument("--lora5", default="", type=str, required=False, help="URL of the LoRA to use.")
452
+ parser.add_argument("--lora5s", default=1.0, type=float, required=False, help="LoRA weight scale of --lora5.")
453
+ parser.add_argument("--loras", default="", type=str, required=False, help="Folder of the LoRA to use.")
454
+
455
+ args = parser.parse_args()
456
+ assert args.url is not None, "Must provide a URL!"
457
+
458
+ is_local = True
459
+ lora_dict = {args.lora1: args.lora1s, args.lora2: args.lora2s, args.lora3: args.lora3s, args.lora4: args.lora4s, args.lora5: args.lora5s}
460
+ if args.loras and Path(args.loras).exists():
461
+ for p in Path(args.loras).glob('**/*.safetensors'):
462
+ lora_dict[str(p)] = 1.0
463
+ ema = not args.nonema
464
+ mtype = "SDXL"
465
+
466
+ convert_url_to_diffusers(args.url, args.civitai_key, dtype=args.dtype, vae=args.vae, clip=args.clip, t5=args.t5, scheduler=args.scheduler, ema=ema, base_repo=args.base, mtype=mtype, lora_dict=lora_dict, is_local=is_local) # keyword arguments keep the long signature aligned
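
For reference, a minimal sketch of calling the converter as a library rather than through the CLI or the Gradio UI; the URL is the example one used in app.py, and keyword arguments make the long signature explicit:

```python
from convert_url_to_diffusers_multi_gr import convert_url_to_diffusers

# Downloads the checkpoint, converts it, and saves a Diffusers-format folder locally.
new_dir = convert_url_to_diffusers(
    url="https://huggingface.co/bluepen5805/blue_pencil-XL/blob/main/blue_pencil-XL-v7.0.0.safetensors",
    dtype="fp16",
    scheduler="Euler a",
    mtype="SDXL",
    is_local=True,
)
print(f"Converted pipeline saved to: {new_dir}")
```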
packages.txt ADDED
@@ -0,0 +1 @@
1
+ git-lfs aria2
presets.py ADDED
@@ -0,0 +1,147 @@
+ from sdutils import get_dtypes, SCHEDULER_CONFIG_MAP
+ import gradio as gr
+
+
+ DEFAULT_DTYPE = get_dtypes()[0]
+ schedulers = list(SCHEDULER_CONFIG_MAP.keys())
+
+
+ clips = [
+     "",
+     "openai/clip-vit-large-patch14",
+ ]
+
+
+ t5s = [
+     "",
+     "https://huggingface.co/camenduru/FLUX.1-dev/blob/main/t5xxl_fp8_e4m3fn.safetensors",
+ ]
+
+
+ sdxl_vaes = [
+     "",
+     "https://huggingface.co/madebyollin/sdxl-vae-fp16-fix/blob/main/sdxl.vae.safetensors",
+     "https://huggingface.co/nubby/blessed-sdxl-vae-fp16-fix/blob/main/sdxl_vae-fp16fix-blessed.safetensors",
+     "https://huggingface.co/John6666/safetensors_converting_test/blob/main/xlVAEC_e7.safetensors",
+     "https://huggingface.co/John6666/safetensors_converting_test/blob/main/xlVAEC_f1.safetensors",
+ ]
+
+
+ sdxl_loras = [
+     "",
+     "https://huggingface.co/SPO-Diffusion-Models/SPO-SDXL_4k-p_10ep_LoRA/blob/main/spo_sdxl_10ep_4k-data_lora_diffusers.safetensors",
+     "https://huggingface.co/wangfuyun/PCM_Weights/blob/main/sdxl/pcm_sdxl_smallcfg_2step_converted.safetensors",
+     "https://huggingface.co/wangfuyun/PCM_Weights/blob/main/sdxl/pcm_sdxl_smallcfg_4step_converted.safetensors",
+     "https://huggingface.co/wangfuyun/PCM_Weights/blob/main/sdxl/pcm_sdxl_smallcfg_8step_converted.safetensors",
+     "https://huggingface.co/wangfuyun/PCM_Weights/blob/main/sdxl/pcm_sdxl_normalcfg_8step_converted.safetensors",
+     "https://huggingface.co/wangfuyun/PCM_Weights/blob/main/sdxl/pcm_sdxl_normalcfg_16step_converted.safetensors",
+     "https://huggingface.co/ByteDance/Hyper-SD/blob/main/Hyper-SDXL-1step-lora.safetensors",
+     "https://huggingface.co/ByteDance/Hyper-SD/blob/main/Hyper-SDXL-2steps-lora.safetensors",
+     "https://huggingface.co/ByteDance/Hyper-SD/blob/main/Hyper-SDXL-4steps-lora.safetensors",
+     "https://huggingface.co/ByteDance/Hyper-SD/blob/main/Hyper-SDXL-8steps-CFG-lora.safetensors",
+     "https://huggingface.co/ByteDance/Hyper-SD/blob/main/Hyper-SDXL-12steps-CFG-lora.safetensors",
+     "https://huggingface.co/latent-consistency/lcm-lora-sdxl/blob/main/pytorch_lora_weights.safetensors",
+ ]
+
+
+ sdxl_preset_items = ["dtype", "vae", "scheduler", "lora1", "lora1s", "lora2", "lora2s", "lora3", "lora3s", "lora4", "lora4s", "lora5", "lora5s"]
+ sdxl_preset_dict = {
+     "Default": [DEFAULT_DTYPE, "", "Euler a", "", 1.0, "", 1.0, "", 1.0, "", 1.0, "", 1.0],
+     "Bake in standard VAE": [DEFAULT_DTYPE, "https://huggingface.co/madebyollin/sdxl-vae-fp16-fix/blob/main/sdxl.vae.safetensors",
+                              "Euler a", "", 1.0, "", 1.0, "", 1.0, "", 1.0, "", 1.0],
+     "Hyper-SDXL / SPO": [DEFAULT_DTYPE, "https://huggingface.co/madebyollin/sdxl-vae-fp16-fix/blob/main/sdxl.vae.safetensors",
+                          "TCD", "https://huggingface.co/ByteDance/Hyper-SD/blob/main/Hyper-SDXL-8steps-CFG-lora.safetensors", 1.0,
+                          "https://huggingface.co/SPO-Diffusion-Models/SPO-SDXL_4k-p_10ep_LoRA/blob/main/spo_sdxl_10ep_4k-data_lora_diffusers.safetensors",
+                          1.0, "", 1.0, "", 1.0, "", 1.0],
+ }
+
+
+ def sdxl_set_presets(preset: str="Default"):
+     p = []
+     if preset in sdxl_preset_dict: p = sdxl_preset_dict[preset]
+     else: p = sdxl_preset_dict["Default"]
+     if len(p) != len(sdxl_preset_items): raise gr.Error("Invalid preset.")
+     print("Setting SDXL preset:", ", ".join([f"{x}:{y}" for x, y in zip(sdxl_preset_items, p)]))
+     return p[0], p[1], p[2], p[3], p[4], p[5], p[6], p[7], p[8], p[9], p[10], p[11], p[12]
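Each *_set_presets helper returns one value per UI control, so it can feed a Gradio change event directly; a minimal wiring sketch, assuming hypothetical component names inside a gr.Blocks context (the actual wiring lives in app.py):

preset_dd = gr.Dropdown(choices=list(sdxl_preset_dict.keys()), value="Default")
preset_dd.change(sdxl_set_presets, inputs=preset_dd,
                 outputs=[dtype_dd, vae_dd, scheduler_dd, lora1_dd, lora1s_sl, lora2_dd, lora2s_sl,
                          lora3_dd, lora3s_sl, lora4_dd, lora4s_sl, lora5_dd, lora5s_sl])  # 13 outputs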
+
+
+ sd15_vaes = [
+     "",
+     "https://huggingface.co/stabilityai/sd-vae-ft-mse-original/resolve/main/vae-ft-mse-840000-ema-pruned.ckpt",
+     "https://huggingface.co/stabilityai/sd-vae-ft-ema-original/resolve/main/vae-ft-ema-560000-ema-pruned.ckpt",
+ ]
+
+
+ sd15_loras = [
+     "",
+     "https://huggingface.co/SPO-Diffusion-Models/SPO-SD-v1-5_4k-p_10ep_LoRA/blob/main/spo-sd-v1-5_4k-p_10ep_lora_diffusers.safetensors",
+ ]
+
+
+ sd15_preset_items = ["dtype", "vae", "scheduler", "lora1", "lora1s", "lora2", "lora2s", "lora3", "lora3s", "lora4", "lora4s", "lora5", "lora5s", "ema"]
+ sd15_preset_dict = {
+     "Default": [DEFAULT_DTYPE, "", "Euler", "", 1.0, "", 1.0, "", 1.0, "", 1.0, "", 1.0, True],
+     "Bake in standard VAE": [DEFAULT_DTYPE, "https://huggingface.co/stabilityai/sd-vae-ft-mse-original/resolve/main/vae-ft-mse-840000-ema-pruned.ckpt",
+                              "Euler", "", 1.0, "", 1.0, "", 1.0, "", 1.0, "", 1.0, True],
+ }
+
+
+ def sd15_set_presets(preset: str="Default"):
+     p = []
+     if preset in sd15_preset_dict: p = sd15_preset_dict[preset]
+     else: p = sd15_preset_dict["Default"]
+     if len(p) != len(sd15_preset_items): raise gr.Error("Invalid preset.")
+     print("Setting SD1.5 preset:", ", ".join([f"{x}:{y}" for x, y in zip(sd15_preset_items, p)]))
+     return p[0], p[1], p[2], p[3], p[4], p[5], p[6], p[7], p[8], p[9], p[10], p[11], p[12], p[13]
+
+
+ flux_vaes = [
+     "",
+ ]
+
+
+ flux_loras = [
+     "",
+ ]
+
+
+ flux_preset_items = ["dtype", "vae", "scheduler", "lora1", "lora1s", "lora2", "lora2s", "lora3", "lora3s", "lora4", "lora4s", "lora5", "lora5s", "base_repo"]
+ flux_preset_dict = {
+     "dev": ["bf16", "", "", "", 1.0, "", 1.0, "", 1.0, "", 1.0, "", 1.0, "camenduru/FLUX.1-dev-diffusers"],
+     "schnell": ["bf16", "", "", "", 1.0, "", 1.0, "", 1.0, "", 1.0, "", 1.0, "black-forest-labs/FLUX.1-schnell"],
+ }
+
+
+ def flux_set_presets(preset: str="dev"):
+     p = []
+     if preset in flux_preset_dict: p = flux_preset_dict[preset]
+     else: p = flux_preset_dict["dev"]
+     if len(p) != len(flux_preset_items): raise gr.Error("Invalid preset.")
+     print("Setting FLUX.1 preset:", ", ".join([f"{x}:{y}" for x, y in zip(flux_preset_items, p)]))
+     return p[0], p[1], p[2], p[3], p[4], p[5], p[6], p[7], p[8], p[9], p[10], p[11], p[12], p[13]
+
+
+
+ sd35_vaes = [
+     "",
+ ]
+
+
+ sd35_loras = [
+     "",
+ ]
+
+
+ sd35_preset_items = ["dtype", "vae", "scheduler", "lora1", "lora1s", "lora2", "lora2s", "lora3", "lora3s", "lora4", "lora4s", "lora5", "lora5s", "base_repo"]
+ sd35_preset_dict = {
+     "Default": ["bf16", "", "", "", 1.0, "", 1.0, "", 1.0, "", 1.0, "", 1.0, "adamo1139/stable-diffusion-3.5-large-ungated"],
+ }
+
+
+ def sd35_set_presets(preset: str="Default"):
+     p = []
+     if preset in sd35_preset_dict: p = sd35_preset_dict[preset]
+     else: p = sd35_preset_dict["Default"]
+     if len(p) != len(sd35_preset_items): raise gr.Error("Invalid preset.")
+     print("Setting SD3.5 preset:", ", ".join([f"{x}:{y}" for x, y in zip(sd35_preset_items, p)]))
+     return p[0], p[1], p[2], p[3], p[4], p[5], p[6], p[7], p[8], p[9], p[10], p[11], p[12], p[13]
requirements.txt ADDED
@@ -0,0 +1,12 @@
+ huggingface_hub
+ safetensors
+ transformers==4.46.3
+ diffusers==0.31.0
+ peft
+ sentencepiece
+ torch==2.5.1
+ pytorch_lightning
+ gdown
+ bitsandbytes
+ accelerate
+ numpy<2
sdutils.py ADDED
@@ -0,0 +1,170 @@
+ import torch
+ from pathlib import Path
+ from utils import get_download_file
+ from stkey import read_safetensors_key
+ try:
+     from diffusers import BitsAndBytesConfig
+     is_nf4 = True
+ except Exception:
+     is_nf4 = False
+
+
+ DTYPE_DEFAULT = "default"
+ DTYPE_DICT = {
+     "fp16": torch.float16,
+     "bf16": torch.bfloat16,
+     "fp32": torch.float32,
+     "fp8": torch.float8_e4m3fn,
+ }
+ #QTYPES = ["NF4"] if is_nf4 else []
+ QTYPES = []
+
+ def get_dtypes():
+     return list(DTYPE_DICT.keys()) + [DTYPE_DEFAULT] + QTYPES
+
+
+ def get_dtype(dtype: str):
+     if dtype in set(QTYPES): return torch.bfloat16
+     return DTYPE_DICT.get(dtype, torch.float16)
+
+
+ from diffusers import (
+     DPMSolverMultistepScheduler,
+     DPMSolverSinglestepScheduler,
+     KDPM2DiscreteScheduler,
+     EulerDiscreteScheduler,
+     EulerAncestralDiscreteScheduler,
+     HeunDiscreteScheduler,
+     LMSDiscreteScheduler,
+     DDIMScheduler,
+     DEISMultistepScheduler,
+     UniPCMultistepScheduler,
+     LCMScheduler,
+     PNDMScheduler,
+     KDPM2AncestralDiscreteScheduler,
+     DPMSolverSDEScheduler,
+     EDMDPMSolverMultistepScheduler,
+     DDPMScheduler,
+     EDMEulerScheduler,
+     TCDScheduler,
+ )
+
+
+ SCHEDULER_CONFIG_MAP = {
+     "DPM++ 2M": (DPMSolverMultistepScheduler, {"algorithm_type": "dpmsolver++", "use_karras_sigmas": False}),
+     "DPM++ 2M Karras": (DPMSolverMultistepScheduler, {"algorithm_type": "dpmsolver++", "use_karras_sigmas": True}),
+     "DPM++ 2M SDE": (DPMSolverMultistepScheduler, {"use_karras_sigmas": False, "algorithm_type": "sde-dpmsolver++"}),
+     "DPM++ 2M SDE Karras": (DPMSolverMultistepScheduler, {"use_karras_sigmas": True, "algorithm_type": "sde-dpmsolver++"}),
+     "DPM++ 2S": (DPMSolverSinglestepScheduler, {"algorithm_type": "dpmsolver++", "use_karras_sigmas": False}),
+     "DPM++ 2S Karras": (DPMSolverSinglestepScheduler, {"algorithm_type": "dpmsolver++", "use_karras_sigmas": True}),
+     "DPM++ 1S": (DPMSolverMultistepScheduler, {"algorithm_type": "dpmsolver++", "solver_order": 1}),
+     "DPM++ 1S Karras": (DPMSolverMultistepScheduler, {"algorithm_type": "dpmsolver++", "solver_order": 1, "use_karras_sigmas": True}),
+     "DPM++ 3M": (DPMSolverMultistepScheduler, {"algorithm_type": "dpmsolver++", "solver_order": 3}),
+     "DPM++ 3M Karras": (DPMSolverMultistepScheduler, {"algorithm_type": "dpmsolver++", "solver_order": 3, "use_karras_sigmas": True}),
+     "DPM 3M": (DPMSolverMultistepScheduler, {"algorithm_type": "dpmsolver", "final_sigmas_type": "sigma_min", "solver_order": 3}),
+     "DPM++ SDE": (DPMSolverSDEScheduler, {"use_karras_sigmas": False}),
+     "DPM++ SDE Karras": (DPMSolverSDEScheduler, {"use_karras_sigmas": True}),
+     "DPM2": (KDPM2DiscreteScheduler, {}),
+     "DPM2 Karras": (KDPM2DiscreteScheduler, {"use_karras_sigmas": True}),
+     "DPM2 a": (KDPM2AncestralDiscreteScheduler, {}),
+     "DPM2 a Karras": (KDPM2AncestralDiscreteScheduler, {"use_karras_sigmas": True}),
+     "Euler": (EulerDiscreteScheduler, {}),
+     "Euler a": (EulerAncestralDiscreteScheduler, {}),
+     "Euler trailing": (EulerDiscreteScheduler, {"timestep_spacing": "trailing", "prediction_type": "sample"}),
+     "Euler a trailing": (EulerAncestralDiscreteScheduler, {"timestep_spacing": "trailing"}),
+     "Heun": (HeunDiscreteScheduler, {}),
+     "Heun Karras": (HeunDiscreteScheduler, {"use_karras_sigmas": True}),
+     "LMS": (LMSDiscreteScheduler, {}),
+     "LMS Karras": (LMSDiscreteScheduler, {"use_karras_sigmas": True}),
+     "DDIM": (DDIMScheduler, {}),
+     "DDIM trailing": (DDIMScheduler, {"timestep_spacing": "trailing"}),
+     "DEIS": (DEISMultistepScheduler, {}),
+     "UniPC": (UniPCMultistepScheduler, {}),
+     "UniPC Karras": (UniPCMultistepScheduler, {"use_karras_sigmas": True}),
+     "PNDM": (PNDMScheduler, {}),
+     "Euler EDM": (EDMEulerScheduler, {}),
+     "Euler EDM Karras": (EDMEulerScheduler, {"use_karras_sigmas": True}),
+     "DPM++ 2M EDM": (EDMDPMSolverMultistepScheduler, {"solver_order": 2, "solver_type": "midpoint", "final_sigmas_type": "zero", "algorithm_type": "dpmsolver++"}),
+     "DPM++ 2M EDM Karras": (EDMDPMSolverMultistepScheduler, {"use_karras_sigmas": True, "solver_order": 2, "solver_type": "midpoint", "final_sigmas_type": "zero", "algorithm_type": "dpmsolver++"}),
+     "DDPM": (DDPMScheduler, {}),
+
+     "DPM++ 2M Lu": (DPMSolverMultistepScheduler, {"algorithm_type": "dpmsolver++", "use_lu_lambdas": True}),
+     "DPM++ 2M Ef": (DPMSolverMultistepScheduler, {"algorithm_type": "dpmsolver++", "euler_at_final": True}),
+     "DPM++ 2M SDE Lu": (DPMSolverMultistepScheduler, {"use_lu_lambdas": True, "algorithm_type": "sde-dpmsolver++"}),
+     "DPM++ 2M SDE Ef": (DPMSolverMultistepScheduler, {"algorithm_type": "sde-dpmsolver++", "euler_at_final": True}),
+
+     "LCM": (LCMScheduler, {}),
+     "TCD": (TCDScheduler, {}),
+     "LCM trailing": (LCMScheduler, {"timestep_spacing": "trailing"}),
+     "TCD trailing": (TCDScheduler, {"timestep_spacing": "trailing"}),
+     "LCM Auto-Loader": (LCMScheduler, {}),
+     "TCD Auto-Loader": (TCDScheduler, {}),
+
+     "EDM": (EDMDPMSolverMultistepScheduler, {}),
+     "EDM Karras": (EDMDPMSolverMultistepScheduler, {"use_karras_sigmas": True}),
+
+     "Euler (V-Prediction)": (EulerDiscreteScheduler, {"prediction_type": "v_prediction", "rescale_betas_zero_snr": True}),
+     "Euler a (V-Prediction)": (EulerAncestralDiscreteScheduler, {"prediction_type": "v_prediction", "rescale_betas_zero_snr": True}),
+     "Euler EDM (V-Prediction)": (EDMEulerScheduler, {"prediction_type": "v_prediction"}),
+     "Euler EDM Karras (V-Prediction)": (EDMEulerScheduler, {"use_karras_sigmas": True, "prediction_type": "v_prediction"}),
+     "DPM++ 2M EDM (V-Prediction)": (EDMDPMSolverMultistepScheduler, {"solver_order": 2, "solver_type": "midpoint", "final_sigmas_type": "zero", "algorithm_type": "dpmsolver++", "prediction_type": "v_prediction"}),
+     "DPM++ 2M EDM Karras (V-Prediction)": (EDMDPMSolverMultistepScheduler, {"use_karras_sigmas": True, "solver_order": 2, "solver_type": "midpoint", "final_sigmas_type": "zero", "algorithm_type": "dpmsolver++", "prediction_type": "v_prediction"}),
+     "EDM (V-Prediction)": (EDMDPMSolverMultistepScheduler, {"prediction_type": "v_prediction"}),
+     "EDM Karras (V-Prediction)": (EDMDPMSolverMultistepScheduler, {"use_karras_sigmas": True, "prediction_type": "v_prediction"}),
+ }
+
+
+ def get_scheduler_config(name: str):
+     if name not in SCHEDULER_CONFIG_MAP: return SCHEDULER_CONFIG_MAP["Euler a"]
+     return SCHEDULER_CONFIG_MAP[name]
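A hedged usage sketch for the map above, relying on the standard diffusers from_config API (pipe is assumed to be an already-loaded pipeline):

cls, config = get_scheduler_config("DPM++ 2M Karras")
pipe.scheduler = cls.from_config(pipe.scheduler.config, **config)  # swap the scheduler in place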
+
+
+ def fuse_loras(pipe, lora_dict: dict, temp_dir: str, civitai_key: str="", dkwargs: dict={}):
+     if not lora_dict or not isinstance(lora_dict, dict): return pipe
+     a_list = []
+     w_list = []
+     for k, v in lora_dict.items():
+         if not k: continue
+         new_lora_file = get_download_file(temp_dir, k, civitai_key)
+         if not new_lora_file or not Path(new_lora_file).exists():
+             print(f"LoRA file not found: {k}")
+             continue
+         w_name = Path(new_lora_file).name
+         a_name = Path(new_lora_file).stem
+         pipe.load_lora_weights(new_lora_file, weight_name=w_name, adapter_name=a_name, low_cpu_mem_usage=False, **dkwargs)
+         a_list.append(a_name)
+         w_list.append(v)
+         if Path(new_lora_file).exists(): Path(new_lora_file).unlink()
+     if len(a_list) == 0: return pipe
+     pipe.set_adapters(a_list, adapter_weights=w_list)
+     pipe.fuse_lora(adapter_names=a_list, lora_scale=1.0)
+     pipe.unload_lora_weights()
+     return pipe
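A minimal usage sketch for fuse_loras, assuming an SDXL pipeline and placeholder LoRA URLs (the dict values are the adapter weights):

from diffusers import StableDiffusionXLPipeline

pipe = StableDiffusionXLPipeline.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0")
lora_dict = {
    "https://example.com/style_lora.safetensors": 0.8,   # placeholder URL
    "https://example.com/detail_lora.safetensors": 0.5,  # placeholder URL
}
pipe = fuse_loras(pipe, lora_dict, temp_dir=".", civitai_key="")  # LoRAs are fused into the weights, then unloaded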
+
+
+ MODEL_TYPE_KEY = {
+     "model.diffusion_model.output_blocks.1.1.norm.bias": "SDXL",
+     "model.diffusion_model.input_blocks.11.0.out_layers.3.weight": "SD 1.5",
+     "double_blocks.0.img_attn.norm.key_norm.scale": "FLUX",
+     "model.diffusion_model.double_blocks.0.img_attn.norm.key_norm.scale": "FLUX",
+     "model.diffusion_model.joint_blocks.9.x_block.attn.ln_k.weight": "SD 3.5",
+ }
+
+
+ def get_model_type_from_key(path: str):
+     default = "SDXL"
+     try:
+         keys = read_safetensors_key(path)
+         for k, v in MODEL_TYPE_KEY.items():
+             if k in set(keys):
+                 print(f"Model type is {v}.")
+                 return v
+         print("Model type could not be identified.")
+     except Exception:
+         return default
+     return default
+
+
+ def get_process_dtype(dtype: str, model_type: str):
+     if dtype in set(["fp8"] + QTYPES): return torch.bfloat16 if model_type in ["FLUX", "SD 3.5"] else torch.float16
+     return DTYPE_DICT.get(dtype, torch.float16)
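Taken together, a hedged sketch of how the two helpers above combine during conversion (the local filename is a placeholder):

mtype = get_model_type_from_key("model.safetensors")  # e.g. "SDXL", "FLUX", "SD 3.5"
dtype = get_process_dtype("fp8", mtype)               # fp8 checkpoints are processed in fp16/bf16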
stkey.py ADDED
@@ -0,0 +1,124 @@
+ import argparse
+ from pathlib import Path
+ import json
+ import re
+ import gc
+ from safetensors.torch import load_file, save_file
+ import torch
+
+
+ SDXL_KEYS_FILE = "keys/sdxl_keys.txt"
+
+
+ def list_uniq(l):
+     return sorted(set(l), key=l.index)
+
+
+ def read_safetensors_metadata(path: str):
+     with open(path, 'rb') as f:
+         header_size = int.from_bytes(f.read(8), 'little')
+         header_json = f.read(header_size).decode('utf-8')
+         header = json.loads(header_json)
+         metadata = header.get('__metadata__', {})
+     return metadata
+
+
+ def keys_from_file(path: str):
+     keys = []
+     try:
+         with open(str(Path(path)), encoding='utf-8', mode='r') as f:
+             lines = f.readlines()
+         for line in lines:
+             keys.append(line.strip())
+     except Exception as e:
+         print(e)
+     finally:
+         return keys
+
+
+ def validate_keys(keys: list[str], rfile: str=SDXL_KEYS_FILE):
+     missing = []
+     added = []
+     try:
+         rkeys = keys_from_file(rfile)
+         all_keys = list_uniq(keys + rkeys)
+         for key in all_keys:
+             if key in set(rkeys) and key not in set(keys): missing.append(key)
+             if key in set(keys) and key not in set(rkeys): added.append(key)
+     except Exception as e:
+         print(e)
+     finally:
+         return missing, added
+
+
+ def read_safetensors_key(path: str):
+     keys = []
+     state_dict = {}
+     try:
+         state_dict = load_file(str(Path(path)))
+         for k in list(state_dict.keys()):
+             keys.append(k)
+             state_dict.pop(k)
+     except Exception as e:
+         print(e)
+     finally:
+         del state_dict
+         torch.cuda.empty_cache()
+         gc.collect()
+     return keys
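Because load_file materializes every tensor just to enumerate names, a lighter variant is possible; a sketch (this helper is hypothetical, not part of this commit) reusing the 8-byte-size + JSON-header layout already parsed by read_safetensors_metadata:

def read_safetensors_key_lite(path: str):  # hypothetical header-only variant
    with open(path, 'rb') as f:
        header_size = int.from_bytes(f.read(8), 'little')
        header = json.loads(f.read(header_size).decode('utf-8'))
    return [k for k in header.keys() if k != '__metadata__']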
+
+
+ def write_safetensors_key(keys: list[str], path: str, is_validate: bool=True, rpath: str=SDXL_KEYS_FILE):
+     if len(keys) == 0: return False
+     try:
+         with open(str(Path(path)), encoding='utf-8', mode='w') as f:
+             f.write("\n".join(keys))
+         if is_validate:
+             missing, added = validate_keys(keys, rpath)
+             with open(str(Path(path).stem + "_missing.txt"), encoding='utf-8', mode='w') as f:
+                 f.write("\n".join(missing))
+             with open(str(Path(path).stem + "_added.txt"), encoding='utf-8', mode='w') as f:
+                 f.write("\n".join(added))
+         return True
+     except Exception as e:
+         print(e)
+         return False
+
+
+ def stkey(input: str, out_filename: str="", is_validate: bool=True, rfile: str=SDXL_KEYS_FILE):
+     keys = read_safetensors_key(input)
+     if len(keys) != 0 and out_filename: write_safetensors_key(keys, out_filename, is_validate, rfile)
+     if len(keys) != 0:
+         print("Metadata:")
+         print(read_safetensors_metadata(input))
+         print("\nKeys:")
+         print("\n".join(keys))
+     if is_validate:
+         missing, added = validate_keys(keys, rfile)
+         print("\nMissing Keys:")
+         print("\n".join(missing))
+         print("\nAdded Keys:")
+         print("\n".join(added))
+
+
+ if __name__ == "__main__":
104
+ parser = argparse.ArgumentParser()
105
+ parser.add_argument("input", type=str, help="Input safetensors file.")
106
+ parser.add_argument("-s", "--save", action="store_true", default=False, help="Output to text file.")
107
+ parser.add_argument("-o", "--output", default="", type=str, help="Output to specific text file.")
108
+ parser.add_argument("-v", "--val", action="store_false", default=True, help="Disable key validation.")
109
+ parser.add_argument("-r", "--rfile", default=SDXL_KEYS_FILE, type=str, help="Specify reference file to validate keys.")
110
+
111
+ args = parser.parse_args()
112
+
113
+ if args.save: out_filename = Path(args.input).stem + ".txt"
114
+ out_filename = args.output if args.output else out_filename
115
+
116
+ stkey(args.input, out_filename, args.val, args.rfile)
+
+
+ # Usage:
+ # python stkey.py sd_xl_base_1.0_0.9vae.safetensors
+ # python stkey.py sd_xl_base_1.0_0.9vae.safetensors -s
+ # python stkey.py sd_xl_base_1.0_0.9vae.safetensors -o key.txt
utils.py ADDED
@@ -0,0 +1,298 @@
+ import gradio as gr
+ from huggingface_hub import HfApi, HfFolder, hf_hub_download, snapshot_download
+ import os
+ from pathlib import Path
+ import shutil
+ import gc
+ import re
+ import urllib.parse
+ import subprocess
+ import time
+ from typing import Any
+
+
+ def get_token():
+     try:
+         token = HfFolder.get_token()
+     except Exception:
+         token = ""
+     return token
+
+
+ def set_token(token):
+     try:
+         HfFolder.save_token(token)
+     except Exception:
+         print("Error: Failed to save token.")
+
+
+ def get_state(state: dict, key: str):
+     if key in state.keys(): return state[key]
+     else:
+         print(f"State '{key}' not found.")
+         return None
+
+
+ def set_state(state: dict, key: str, value: Any):
+     state[key] = value
+
+
+ def get_user_agent():
+     return 'Mozilla/5.0 (Windows NT 10.0; Win64; x64; rv:127.0) Gecko/20100101 Firefox/127.0'
+
+
+ def is_repo_exists(repo_id: str, repo_type: str="model"):
+     hf_token = get_token()
+     api = HfApi(token=hf_token)
+     try:
+         if api.repo_exists(repo_id=repo_id, repo_type=repo_type, token=hf_token): return True
+         else: return False
+     except Exception as e:
+         print(f"Error: Failed to connect {repo_id} ({repo_type}). {e}")
+         return True # treat as existing, to be safe
+
+
+ MODEL_TYPE_CLASS = {
+     "diffusers:StableDiffusionPipeline": "SD 1.5",
+     "diffusers:StableDiffusionXLPipeline": "SDXL",
+     "diffusers:FluxPipeline": "FLUX",
+ }
+
+
+ def get_model_type(repo_id: str):
+     hf_token = get_token()
+     api = HfApi(token=hf_token)
+     lora_filename = "pytorch_lora_weights.safetensors"
+     diffusers_filename = "model_index.json"
+     default = "SDXL"
+     try:
+         if api.file_exists(repo_id=repo_id, filename=lora_filename, token=hf_token): return "LoRA"
+         if not api.file_exists(repo_id=repo_id, filename=diffusers_filename, token=hf_token): return "None"
+         model = api.model_info(repo_id=repo_id, token=hf_token)
+         tags = model.tags
+         for tag in tags:
+             if tag in MODEL_TYPE_CLASS.keys(): return MODEL_TYPE_CLASS.get(tag, default)
+     except Exception:
+         return default
+     return default
+
+
+ def list_uniq(l):
+     return sorted(set(l), key=l.index)
+
+
+ def list_sub(a, b):
+     return [e for e in a if e not in b]
+
+
+ def is_repo_name(s):
+     return re.fullmatch(r'^[\w_\-\.]+/[\w_\-\.]+$', s)
+
+
+ def get_hf_url(repo_id: str, repo_type: str="model"):
+     if repo_type == "dataset": url = f"https://huggingface.co/datasets/{repo_id}"
+     elif repo_type == "space": url = f"https://huggingface.co/spaces/{repo_id}"
+     else: url = f"https://huggingface.co/{repo_id}"
+     return url
+
+
+ def split_hf_url(url: str):
+     try:
+         s = list(re.findall(r'^(?:https?://huggingface\.co/)(?:(datasets|spaces)/)?(.+?/.+?)/\w+?/.+?/(?:(.+)/)?(.+?\.\w+)(?:\?download=true)?$', url)[0])
+         if len(s) < 4: return "", "", None, ""
+         repo_id = s[1]
+         if s[0] == "datasets": repo_type = "dataset"
+         elif s[0] == "spaces": repo_type = "space"
+         else: repo_type = "model"
+         subfolder = urllib.parse.unquote(s[2]) if s[2] else None
+         filename = urllib.parse.unquote(s[3])
+         return repo_id, filename, subfolder, repo_type
+     except Exception as e:
+         print(e)
+         return "", "", None, ""
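For reference, what this parser yields for a typical blob URL (the example URL appears in presets.py):

# split_hf_url("https://huggingface.co/madebyollin/sdxl-vae-fp16-fix/blob/main/sdxl.vae.safetensors")
# -> ("madebyollin/sdxl-vae-fp16-fix", "sdxl.vae.safetensors", None, "model")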
+
+
+ def download_hf_file(directory, url, progress=gr.Progress(track_tqdm=True)):
+     hf_token = get_token()
+     repo_id, filename, subfolder, repo_type = split_hf_url(url)
+     try:
+         print(f"Downloading {url} to {directory}")
+         if subfolder is not None: path = hf_hub_download(repo_id=repo_id, filename=filename, subfolder=subfolder, repo_type=repo_type, local_dir=directory, token=hf_token)
+         else: path = hf_hub_download(repo_id=repo_id, filename=filename, repo_type=repo_type, local_dir=directory, token=hf_token)
+         return path
+     except Exception as e:
+         print(f"Failed to download: {e}")
+         return None
+
+
+ def download_thing(directory, url, civitai_api_key="", progress=gr.Progress(track_tqdm=True)): # requires aria2, gdown
+     try:
+         url = url.strip()
+         if "drive.google.com" in url:
+             original_dir = os.getcwd()
+             os.chdir(directory)
+             subprocess.run(f"gdown --fuzzy {url}", shell=True)
+             os.chdir(original_dir)
+         elif "huggingface.co" in url:
+             url = url.replace("?download=true", "")
+             if "/blob/" in url: url = url.replace("/blob/", "/resolve/")
+             download_hf_file(directory, url)
+         elif "civitai.com" in url:
+             if civitai_api_key:
+                 url = f"'{url}&token={civitai_api_key}'" if "?" in url else f"'{url}?token={civitai_api_key}'"
+                 print(f"Downloading {url}")
+                 subprocess.run(f"aria2c --console-log-level=error --summary-interval=10 -c -x 16 -k 1M -s 16 -d {directory} {url}", shell=True)
+             else:
+                 print("You need an API key to download Civitai models.")
+         else:
+             os.system(f"aria2c --console-log-level=error --summary-interval=10 -c -x 16 -k 1M -s 16 -d {directory} {url}")
+     except Exception as e:
+         print(f"Failed to download: {e}")
+
+
+ def get_local_file_list(dir_path):
+     file_list = []
+     for file in Path(dir_path).glob("**/*.*"):
+         if file.is_file():
+             file_path = str(file)
+             file_list.append(file_path)
+     return file_list
+
+
+ def get_download_file(temp_dir, url, civitai_key, progress=gr.Progress(track_tqdm=True)):
+     try:
+         if "http" not in url and is_repo_name(url) and not Path(url).exists():
+             print(f"Use HF Repo: {url}")
+             new_file = url
+         elif "http" not in url and Path(url).exists():
+             print(f"Use local file: {url}")
+             new_file = url
+         elif Path(f"{temp_dir}/{url.split('/')[-1]}").exists():
+             print(f"File to download already exists: {url}")
+             new_file = f"{temp_dir}/{url.split('/')[-1]}"
+         else:
+             print(f"Start downloading: {url}")
+             before = get_local_file_list(temp_dir)
+             download_thing(temp_dir, url.strip(), civitai_key)
+             after = get_local_file_list(temp_dir)
+             new_file = list_sub(after, before)[0] if list_sub(after, before) else ""
+         if not new_file:
+             print(f"Download failed: {url}")
+             return ""
+         print(f"Download completed: {url}")
+         return new_file
+     except Exception as e:
+         print(f"Download failed: {url} {e}")
+         return ""
+
+
+ def download_repo(repo_id: str, dir_path: str, progress=gr.Progress(track_tqdm=True)): # for diffusers repo
+     hf_token = get_token()
+     try:
+         snapshot_download(repo_id=repo_id, local_dir=dir_path, token=hf_token, allow_patterns=["*.safetensors", "*.bin"],
+                           ignore_patterns=["*.fp16.*", "/*.safetensors", "/*.bin"], force_download=True)
+         return True
+     except Exception as e:
+         print(f"Error: Failed to download {repo_id}. {e}")
+         gr.Warning(f"Error: Failed to download {repo_id}. {e}")
+         return False
+
+
+ def upload_repo(repo_id: str, dir_path: str, is_private: bool, is_pr: bool=False, progress=gr.Progress(track_tqdm=True)): # for diffusers repo
+     hf_token = get_token()
+     api = HfApi(token=hf_token)
+     try:
+         progress(0, desc="Start uploading...")
+         api.create_repo(repo_id=repo_id, token=hf_token, private=is_private, exist_ok=True)
+         api.upload_folder(repo_id=repo_id, folder_path=dir_path, path_in_repo="", create_pr=is_pr, token=hf_token)
+         progress(1, desc="Uploaded.")
+         return get_hf_url(repo_id, "model")
+     except Exception as e:
+         print(f"Error: Failed to upload to {repo_id}. {e}")
+         return ""
+
+
+ def gate_repo(repo_id: str, gated_str: str, repo_type: str="model"):
+     hf_token = get_token()
+     api = HfApi(token=hf_token)
+     try:
+         if gated_str == "auto": gated = "auto"
+         elif gated_str == "manual": gated = "manual"
+         else: gated = False
+         api.update_repo_settings(repo_id=repo_id, gated=gated, repo_type=repo_type, token=hf_token)
+     except Exception as e:
+         print(f"Error: Failed to update settings {repo_id}. {e}")
+
+
+ HF_SUBFOLDER_NAME = ["None", "user_repo"]
+
+
+ def duplicate_hf_repo(src_repo: str, dst_repo: str, src_repo_type: str, dst_repo_type: str,
+                       is_private: bool, subfolder_type: str=HF_SUBFOLDER_NAME[1], progress=gr.Progress(track_tqdm=True)):
+     hf_token = get_token()
+     api = HfApi(token=hf_token)
+     try:
+         if subfolder_type == "user_repo": subfolder = src_repo.replace("/", "_")
+         else: subfolder = ""
+         progress(0, desc="Start duplicating...")
+         api.create_repo(repo_id=dst_repo, repo_type=dst_repo_type, private=is_private, exist_ok=True, token=hf_token)
+         for path in api.list_repo_files(repo_id=src_repo, repo_type=src_repo_type, token=hf_token):
+             file = hf_hub_download(repo_id=src_repo, filename=path, repo_type=src_repo_type, token=hf_token)
+             if not Path(file).exists(): continue
+             if Path(file).is_dir(): # unused for now
+                 api.upload_folder(repo_id=dst_repo, folder_path=file, path_in_repo=f"{subfolder}/{path}" if subfolder else path,
+                                   repo_type=dst_repo_type, token=hf_token)
+             elif Path(file).is_file():
+                 api.upload_file(repo_id=dst_repo, path_or_fileobj=file, path_in_repo=f"{subfolder}/{path}" if subfolder else path,
+                                 repo_type=dst_repo_type, token=hf_token)
+             if Path(file).exists(): Path(file).unlink()
+         progress(1, desc="Duplicated.")
+         return f"{get_hf_url(dst_repo, dst_repo_type)}/tree/main/{subfolder}" if subfolder else get_hf_url(dst_repo, dst_repo_type)
+     except Exception as e:
+         print(f"Error: Failed to duplicate repo {src_repo} to {dst_repo}. {e}")
+         return ""
+
+
+ BASE_DIR = str(Path(__file__).resolve().parent.resolve())
+ CIVITAI_API_KEY = os.environ.get("CIVITAI_API_KEY")
+
+
+ def get_file(url: str, path: str): # requires aria2, gdown
+     print(f"Downloading {url} to {path}...")
+     get_download_file(path, url, CIVITAI_API_KEY)
+
+
+ def git_clone(url: str, path: str, pip: bool=False, addcmd: str=""): # requires git
+     os.makedirs(str(Path(BASE_DIR, path)), exist_ok=True)
+     os.chdir(Path(BASE_DIR, path))
+     print(f"Cloning {url} to {path}...")
+     cmd = f'git clone {url}'
+     print(f'Running {cmd} at {Path.cwd()}')
+     i = subprocess.run(cmd, shell=True).returncode
+     if i != 0: print(f'Error occurred at running {cmd}')
+     p = url.split("/")[-1]
+     if not Path(p).exists(): return
+     if pip:
+         os.chdir(Path(BASE_DIR, path, p))
+         cmd = 'pip install -r requirements.txt'
+         print(f'Running {cmd} at {Path.cwd()}')
+         i = subprocess.run(cmd, shell=True).returncode
+         if i != 0: print(f'Error occurred at running {cmd}')
+     if addcmd:
+         os.chdir(Path(BASE_DIR, path, p))
+         cmd = addcmd
+         print(f'Running {cmd} at {Path.cwd()}')
+         i = subprocess.run(cmd, shell=True).returncode
+         if i != 0: print(f'Error occurred at running {cmd}')
+
+
+ def run(cmd: str, timeout: float=0):
+     print(f'Running {cmd} at {Path.cwd()}')
+     if timeout == 0:
+         i = subprocess.run(cmd, shell=True).returncode
+         if i != 0: print(f'Error occurred at running {cmd}')
+     else:
+         p = subprocess.Popen(cmd, shell=True)
+         time.sleep(timeout)
+         p.terminate()
+         print(f'Terminated in {timeout} seconds')
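A hedged usage sketch for these last two helpers (the repository URL is a placeholder):

# git_clone("https://github.com/user/repo", "repos", pip=True)  # clone into ./repos/ and pip-install its requirements.txt
# run("python main.py", timeout=30)                             # launch, then terminate after 30 seconds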