import os
import spaces
import argparse
from pathlib import Path
import torch
from diffusers import (DiffusionPipeline, AutoencoderKL, FlowMatchEulerDiscreteScheduler, StableDiffusionXLPipeline, StableDiffusionPipeline,
                       FluxPipeline, FluxTransformer2DModel, SD3Transformer2DModel, StableDiffusion3Pipeline)
from transformers import CLIPTokenizer, CLIPTextModel, CLIPTextModelWithProjection, AutoTokenizer, T5EncoderModel, BitsAndBytesConfig as TFBitsAndBytesConfig
from huggingface_hub import save_torch_state_dict, snapshot_download
from diffusers.loaders.single_file_utils import (convert_flux_transformer_checkpoint_to_diffusers, convert_sd3_transformer_checkpoint_to_diffusers,
                                                 convert_sd3_t5_checkpoint_to_diffusers)
import safetensors.torch
import gradio as gr
import shutil
import gc
import tempfile

from utils import (get_token, set_token, is_repo_exists, is_repo_name, get_download_file, upload_repo)
from sdutils import (SCHEDULER_CONFIG_MAP, get_scheduler_config, fuse_loras, DTYPE_DEFAULT, get_dtype, get_dtypes, get_model_type_from_key, get_process_dtype)


@spaces.GPU
def fake_gpu():
    pass


# BitsAndBytesConfig is only exported by recent diffusers releases; when the import fails, NF4 output is disabled.
try:
    from diffusers import BitsAndBytesConfig
    is_nf4 = True
except Exception:
    is_nf4 = False


FLUX_BASE_REPOS = ["camenduru/FLUX.1-dev-diffusers", "black-forest-labs/FLUX.1-schnell", "John6666/flux1-dev-fp8-flux", "John6666/flux1-schnell-fp8-flux"]
FLUX_T5_URL = "https://huggingface.co/camenduru/FLUX.1-dev/blob/main/t5xxl_fp8_e4m3fn.safetensors"
SD35_BASE_REPOS = ["adamo1139/stable-diffusion-3.5-large-ungated", "adamo1139/stable-diffusion-3.5-large-turbo-ungated"]
SD35_T5_URL = "https://huggingface.co/adamo1139/stable-diffusion-3.5-large-turbo-ungated/blob/main/text_encoders/t5xxl_fp8_e4m3fn.safetensors"
TEMP_DIR = tempfile.mkdtemp()
IS_ZERO = os.environ.get("SPACES_ZERO_GPU") is not None
IS_CUDA = torch.cuda.is_available()


def safe_clean(path: str):
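    """Delete a file or directory tree if it exists, logging the outcome."""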
    try:
        if Path(path).exists():
            if Path(path).is_dir(): shutil.rmtree(str(Path(path)))
            else: Path(path).unlink()
            print(f"Deleted: {path}")
        else: print(f"File not found: {path}")
    except Exception as e:
        print(f"Failed to delete: {path} {e}")


def save_readme_md(dir, url):
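    """Write a minimal README.md with YAML front matter into dir, citing the source model when it is known."""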
    orig_url = ""
    orig_name = ""
    if is_repo_name(url):
        orig_name = url
        orig_url = f"https://huggingface.co/{url}/"
    elif "http" in url:
        orig_name = url
        orig_url = url
    if orig_name and orig_url:
        md = f"""---
license: other
language:
- en
library_name: diffusers
pipeline_tag: text-to-image
tags:
- text-to-image
---
Converted from [{orig_name}]({orig_url}).
"""
    else:
        md = f"""---
license: other
language:
- en
library_name: diffusers
pipeline_tag: text-to-image
tags:
- text-to-image
---
"""
    path = str(Path(dir, "README.md"))
    with open(path, mode='w', encoding="utf-8") as f:
        f.write(md)


def save_module(model, name: str, dir: str, dtype: str="fp8", progress=gr.Progress(track_tqdm=True)):
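    """Move a loaded module to CPU and save its state dict under dir/name as sharded safetensors, casting to float8_e4m3fn when dtype is 'fp8'."""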
    if name in ["vae", "transformer", "unet"]: pattern = "diffusion_pytorch_model{suffix}.safetensors"
    else: pattern = "model{suffix}.safetensors"
    if name in ["transformer", "unet"]: size = "10GB"
    else: size = "5GB"
    path = str(Path(f"{dir.removesuffix('/')}/{name}"))
    os.makedirs(path, exist_ok=True)
    progress(0, desc=f"Saving {name} to {dir}...")
    print(f"Saving {name} to {dir}...")
    model.to("cpu")
    sd = dict(model.state_dict())
    new_sd = {}
    for key in list(sd.keys()):
        q = sd.pop(key)
        if dtype == "fp8": new_sd[key] = q if q.dtype == torch.float8_e4m3fn else q.to(torch.float8_e4m3fn)
        else: new_sd[key] = q
    del sd
    gc.collect()
    save_torch_state_dict(state_dict=new_sd, save_directory=path, filename_pattern=pattern, max_shard_size=size)
    del new_sd
    gc.collect()


def save_module_sd(sd: dict, name: str, dir: str, dtype: str="fp8", progress=gr.Progress(track_tqdm=True)):
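    """Save an already-loaded state dict under dir/name as sharded safetensors, moving tensors to CPU and casting to float8_e4m3fn when dtype is 'fp8'."""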
    if name in ["vae", "transformer", "unet"]: pattern = "diffusion_pytorch_model{suffix}.safetensors"
    else: pattern = "model{suffix}.safetensors"
    if name in ["transformer", "unet"]: size = "10GB"
    else: size = "5GB"
    path = str(Path(f"{dir.removesuffix('/')}/{name}"))
    os.makedirs(path, exist_ok=True)
    progress(0, desc=f"Saving state_dict of {name} to {dir}...")
    print(f"Saving state_dict of {name} to {dir}...")
    new_sd = {}
    for key in list(sd.keys()):
        q = sd.pop(key).to("cpu")
        if dtype == "fp8": new_sd[key] = q if q.dtype == torch.float8_e4m3fn else q.to(torch.float8_e4m3fn)
        else: new_sd[key] = q
    save_torch_state_dict(state_dict=new_sd, save_directory=path, filename_pattern=pattern, max_shard_size=size)
    del new_sd
    gc.collect()


def convert_flux_fp8_cpu(new_file: str, new_dir: str, dtype: str, base_repo: str, civitai_key: str, kwargs: dict, progress=gr.Progress(track_tqdm=True)):
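    """CPU path for fp8 FLUX output: convert the single-file transformer to the diffusers layout, add an fp8 T5 text_encoder_2, and copy the remaining components from base_repo."""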
    temp_dir = TEMP_DIR
    down_dir = str(Path(f"{TEMP_DIR}/down"))
    os.makedirs(down_dir, exist_ok=True)
    hf_token = get_token()
    progress(0.25, desc=f"Loading {new_file}...")
    orig_sd = safetensors.torch.load_file(new_file)
    progress(0.3, desc=f"Converting {new_file}...")
    conv_sd = convert_flux_transformer_checkpoint_to_diffusers(orig_sd)
    del orig_sd
    gc.collect()
    progress(0.35, desc=f"Saving {new_file}...")
    save_module_sd(conv_sd, "transformer", new_dir, dtype)
    del conv_sd
    gc.collect()
    progress(0.5, desc=f"Loading text_encoder_2 from {FLUX_T5_URL}...")
    t5_file = get_download_file(temp_dir, FLUX_T5_URL, civitai_key)
    if not t5_file: raise Exception(f"Safetensors file not found: {FLUX_T5_URL}")
    t5_sd = safetensors.torch.load_file(t5_file)
    safe_clean(t5_file)
    save_module_sd(t5_sd, "text_encoder_2", new_dir, dtype)
    del t5_sd
    gc.collect()
    progress(0.6, desc=f"Loading other components from {base_repo}...")
    pipe = FluxPipeline.from_pretrained(base_repo, transformer=None, text_encoder_2=None, use_safetensors=True, **kwargs,
                                        torch_dtype=torch.bfloat16, token=hf_token)
    pipe.save_pretrained(new_dir)
    progress(0.75, desc=f"Loading nontensor files from {base_repo}...")
    snapshot_download(repo_id=base_repo, local_dir=down_dir, token=hf_token, force_download=True,
                      ignore_patterns=["*.safetensors", "*.sft", ".*", "README*", "*.md", "*.index", "*.jpg", "*.jpeg", "*.png", "*.webp"])
    shutil.copytree(down_dir, new_dir, ignore=shutil.ignore_patterns(".*", "README*", "*.md", "*.jpg", "*.jpeg", "*.png", "*.webp"), dirs_exist_ok=True)
    safe_clean(down_dir)


def convert_sd35_fp8_cpu(new_file: str, new_dir: str, dtype: str, base_repo: str, civitai_key: str, kwargs: dict, progress=gr.Progress(track_tqdm=True)):
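    """CPU path for fp8 SD 3.5 output: convert the single-file transformer and T5 text_encoder_3 to the diffusers layout, then copy the remaining components from base_repo."""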
    temp_dir = TEMP_DIR
    down_dir = str(Path(f"{TEMP_DIR}/down"))
    os.makedirs(down_dir, exist_ok=True)
    hf_token = get_token()
    progress(0.25, desc=f"Loading {new_file}...")
    orig_sd = safetensors.torch.load_file(new_file)
    progress(0.3, desc=f"Converting {new_file}...")
    conv_sd = convert_sd3_transformer_checkpoint_to_diffusers(orig_sd)
    del orig_sd
    gc.collect()
    progress(0.35, desc=f"Saving {new_file}...")
    save_module_sd(conv_sd, "transformer", new_dir, dtype)
    del conv_sd
    gc.collect()
    progress(0.5, desc=f"Loading text_encoder_3 from {SD35_T5_URL}...")
    t5_file = get_download_file(temp_dir, SD35_T5_URL, civitai_key)
    if not t5_file: raise Exception(f"Safetensors file not found: {SD35_T5_URL}")
    t5_sd = safetensors.torch.load_file(t5_file)
    safe_clean(t5_file)
    conv_t5_sd = convert_sd3_t5_checkpoint_to_diffusers(t5_sd)
    del t5_sd
    gc.collect()
    save_module_sd(conv_t5_sd, "text_encoder_3", new_dir, dtype)
    del conv_t5_sd
    gc.collect()
    progress(0.6, desc=f"Loading other components from {base_repo}...")
    pipe = StableDiffusion3Pipeline.from_pretrained(base_repo, transformer=None, text_encoder_3=None, use_safetensors=True, **kwargs,
                                                    torch_dtype=torch.bfloat16, token=hf_token)
    pipe.save_pretrained(new_dir)
    progress(0.75, desc=f"Loading nontensor files from {base_repo}...")
    snapshot_download(repo_id=base_repo, local_dir=down_dir, token=hf_token, force_download=True,
                      ignore_patterns=["*.safetensors", "*.sft", ".*", "README*", "*.md", "*.index", "*.jpg", "*.jpeg", "*.png", "*.webp"])
    shutil.copytree(down_dir, new_dir, ignore=shutil.ignore_patterns(".*", "README*", "*.md", "*.jpg", "*.jpeg", "*.png", "*.webp"), dirs_exist_ok=True)
    safe_clean(down_dir)


def load_and_save_pipeline(pipe, model_type: str, url: str, new_file: str, new_dir: str, dtype: str,
                           scheduler: str, ema: bool, base_repo: str, civitai_key: str, lora_dict: dict,
                           my_vae, my_clip_tokenizer, my_clip_encoder, my_t5_tokenizer, my_t5_encoder,
                           kwargs: dict, dkwargs: dict, progress=gr.Progress(track_tqdm=True)):
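    """Load the checkpoint with the pipeline class matching model_type, fuse any LoRAs, apply the selected scheduler for SDXL / SD 1.5, and save the diffusers-format result to new_dir."""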
    try:
        hf_token = get_token()
        temp_dir = TEMP_DIR
        qkwargs = {}
        tfqkwargs = {}
        if is_nf4:
            nf4_config = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_quant_type="nf4",
                                            bnb_4bit_use_double_quant=True, bnb_4bit_compute_dtype=torch.bfloat16)
            nf4_config_tf = TFBitsAndBytesConfig(load_in_4bit=True, bnb_4bit_quant_type="nf4",
                                                 bnb_4bit_use_double_quant=True, bnb_4bit_compute_dtype=torch.bfloat16)
        else:
            nf4_config = None
            nf4_config_tf = None
        if dtype == "NF4" and nf4_config is not None and nf4_config_tf is not None:
            qkwargs["quantization_config"] = nf4_config
            tfqkwargs["quantization_config"] = nf4_config_tf

        # Dispatch on the detected model type; fp8 single-file FLUX / SD 3.5 checkpoints take the CPU conversion path.
        if model_type == "SDXL":
            if is_repo_name(url): pipe = StableDiffusionXLPipeline.from_pretrained(url, use_safetensors=True, **kwargs, **dkwargs, token=hf_token)
            else: pipe = StableDiffusionXLPipeline.from_single_file(new_file, use_safetensors=True, **kwargs, **dkwargs)
            pipe = fuse_loras(pipe, lora_dict, temp_dir, civitai_key, dkwargs)
            sconf = get_scheduler_config(scheduler)
            pipe.scheduler = sconf[0].from_config(pipe.scheduler.config, **sconf[1])
            pipe.save_pretrained(new_dir)
        elif model_type == "SD 1.5":
            if is_repo_name(url): pipe = StableDiffusionPipeline.from_pretrained(url, extract_ema=ema, requires_safety_checker=False,
                                                                                 use_safetensors=True, **kwargs, **dkwargs, token=hf_token)
            else: pipe = StableDiffusionPipeline.from_single_file(new_file, extract_ema=ema, requires_safety_checker=False, use_safetensors=True, **kwargs, **dkwargs)
            pipe = fuse_loras(pipe, lora_dict, temp_dir, civitai_key, dkwargs)
            sconf = get_scheduler_config(scheduler)
            pipe.scheduler = sconf[0].from_config(pipe.scheduler.config, **sconf[1])
            pipe.save_pretrained(new_dir)
        elif model_type == "FLUX":
            if dtype != "fp8":
                if is_repo_name(url):
                    transformer = FluxTransformer2DModel.from_pretrained(url, subfolder="transformer", config=base_repo, **dkwargs, **qkwargs)
                    pipe = FluxPipeline.from_pretrained(url, transformer=transformer, use_safetensors=True, **kwargs, **dkwargs, token=hf_token)
                else:
                    transformer = FluxTransformer2DModel.from_single_file(new_file, subfolder="transformer", config=base_repo, **dkwargs, **qkwargs)
                    pipe = FluxPipeline.from_pretrained(base_repo, transformer=transformer, use_safetensors=True, **kwargs, **dkwargs, token=hf_token)
                pipe = fuse_loras(pipe, lora_dict, temp_dir, civitai_key, dkwargs)
                pipe.save_pretrained(new_dir)
            elif not is_repo_name(url): convert_flux_fp8_cpu(new_file, new_dir, dtype, base_repo, civitai_key, kwargs)
        elif model_type == "SD 3.5":
            if dtype != "fp8":
                if is_repo_name(url):
                    transformer = SD3Transformer2DModel.from_pretrained(url, subfolder="transformer", config=base_repo, **dkwargs, **qkwargs)
                    pipe = StableDiffusion3Pipeline.from_pretrained(url, transformer=transformer, use_safetensors=True, **kwargs, **dkwargs, token=hf_token)
                else:
                    transformer = SD3Transformer2DModel.from_single_file(new_file, subfolder="transformer", config=base_repo, **dkwargs, **qkwargs)
                    pipe = StableDiffusion3Pipeline.from_pretrained(base_repo, transformer=transformer, use_safetensors=True, **kwargs, **dkwargs, token=hf_token)
                pipe = fuse_loras(pipe, lora_dict, temp_dir, civitai_key, dkwargs)
                pipe.save_pretrained(new_dir)
            elif not is_repo_name(url): convert_sd35_fp8_cpu(new_file, new_dir, dtype, base_repo, civitai_key, kwargs)
        else:
            if is_repo_name(url): pipe = DiffusionPipeline.from_pretrained(url, use_safetensors=True, **kwargs, **dkwargs, token=hf_token)
            else: pipe = DiffusionPipeline.from_single_file(new_file, use_safetensors=True, **kwargs, **dkwargs)
            pipe = fuse_loras(pipe, lora_dict, temp_dir, civitai_key, dkwargs)
            pipe.save_pretrained(new_dir)
    except Exception as e:
        print(f"Failed to load pipeline. {e}")
        raise Exception("Failed to load pipeline.") from e
    return pipe


def convert_url_to_diffusers(url: str, civitai_key: str="", is_upload_sf: bool=False, dtype: str="fp16", vae: str="", clip: str="", t5: str="",
                             scheduler: str="Euler a", ema: bool=True, base_repo: str="", mtype: str="", lora_dict: dict={}, is_local: bool=True, progress=gr.Progress(track_tqdm=True)):
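    """Resolve url to a checkpoint, optionally swap in VAE / CLIP / T5 components and LoRAs, convert everything to the diffusers layout, and return the output directory name."""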
    try:
        pipe = None  # defined up front so the cleanup in the finally block never hits an unbound name
        hf_token = get_token()
        progress(0, desc="Start converting...")
        temp_dir = TEMP_DIR

        if is_repo_name(url) and is_repo_exists(url):
            new_file = url
            model_type = mtype
        else:
            new_file = get_download_file(temp_dir, url, civitai_key)
            if not new_file: raise Exception(f"Safetensors file not found: {url}")
            model_type = get_model_type_from_key(new_file)
        new_dir = Path(new_file).stem.replace(" ", "_").replace(",", "_").replace(".", "_")

        kwargs = {}
        dkwargs = {}
        if dtype != DTYPE_DEFAULT: dkwargs["torch_dtype"] = get_process_dtype(dtype, model_type)

        print(f"Model type: {model_type} / VAE: {vae} / CLIP: {clip} / T5: {t5} / Scheduler: {scheduler} / dtype: {dtype} / EMA: {ema} / Base repo: {base_repo} / LoRAs: {lora_dict}")

        my_vae = None
        if vae:
            progress(0, desc=f"Loading VAE: {vae}...")
            if is_repo_name(vae): my_vae = AutoencoderKL.from_pretrained(vae, **dkwargs, token=hf_token)
            else:
                new_vae_file = get_download_file(temp_dir, vae, civitai_key)
                my_vae = AutoencoderKL.from_single_file(new_vae_file, **dkwargs) if new_vae_file else None
                safe_clean(new_vae_file)
            if my_vae: kwargs["vae"] = my_vae

        my_clip_tokenizer = None
        my_clip_encoder = None
        if clip:
            progress(0, desc=f"Loading CLIP: {clip}...")
            if is_repo_name(clip):
                my_clip_tokenizer = CLIPTokenizer.from_pretrained(clip, token=hf_token)
                if model_type == "SD 3.5": my_clip_encoder = CLIPTextModelWithProjection.from_pretrained(clip, **dkwargs, token=hf_token)
                else: my_clip_encoder = CLIPTextModel.from_pretrained(clip, **dkwargs, token=hf_token)
            else:
                new_clip_file = get_download_file(temp_dir, clip, civitai_key)
                if model_type == "SD 3.5": my_clip_encoder = CLIPTextModelWithProjection.from_single_file(new_clip_file, **dkwargs) if new_clip_file else None
                else: my_clip_encoder = CLIPTextModel.from_single_file(new_clip_file, **dkwargs) if new_clip_file else None
                safe_clean(new_clip_file)
            if model_type == "SD 3.5":
                if my_clip_tokenizer:
                    kwargs["tokenizer"] = my_clip_tokenizer
                    kwargs["tokenizer_2"] = my_clip_tokenizer
                if my_clip_encoder:
                    kwargs["text_encoder"] = my_clip_encoder
                    kwargs["text_encoder_2"] = my_clip_encoder
            else:
                if my_clip_tokenizer: kwargs["tokenizer"] = my_clip_tokenizer
                if my_clip_encoder: kwargs["text_encoder"] = my_clip_encoder

        my_t5_tokenizer = None
        my_t5_encoder = None
        if t5:
            progress(0, desc=f"Loading T5: {t5}...")
            if is_repo_name(t5):
                my_t5_tokenizer = AutoTokenizer.from_pretrained(t5, token=hf_token)
                my_t5_encoder = T5EncoderModel.from_pretrained(t5, **dkwargs, token=hf_token)
            else:
                new_t5_file = get_download_file(temp_dir, t5, civitai_key)
                my_t5_encoder = T5EncoderModel.from_single_file(new_t5_file, **dkwargs) if new_t5_file else None
                safe_clean(new_t5_file)
            if model_type == "SD 3.5":
                if my_t5_tokenizer: kwargs["tokenizer_3"] = my_t5_tokenizer
                if my_t5_encoder: kwargs["text_encoder_3"] = my_t5_encoder
            else:
                if my_t5_tokenizer: kwargs["tokenizer_2"] = my_t5_tokenizer
                if my_t5_encoder: kwargs["text_encoder_2"] = my_t5_encoder

        pipe = load_and_save_pipeline(pipe, model_type, url, new_file, new_dir, dtype, scheduler, ema, base_repo, civitai_key, lora_dict,
                                      my_vae, my_clip_tokenizer, my_clip_encoder, my_t5_tokenizer, my_t5_encoder, kwargs, dkwargs)

        if Path(new_dir).exists(): save_readme_md(new_dir, url)

        if not is_local:
            if not is_repo_name(new_file) and is_upload_sf: shutil.move(str(Path(new_file).resolve()), str(Path(new_dir, Path(new_file).name).resolve()))
            else: safe_clean(new_file)

        progress(1, desc="Converted.")
        return new_dir
    except Exception as e:
        print(f"Failed to convert. {e}")
        raise Exception("Failed to convert.") from e
    finally:
        del pipe
        torch.cuda.empty_cache()
        gc.collect()


def convert_url_to_diffusers_repo(dl_url: str, hf_user: str, hf_repo: str, hf_token: str, civitai_key="", is_private: bool=True, is_overwrite: bool=False,
                                  is_upload_sf: bool=False, urls: list=[], dtype: str="fp16", vae: str="", clip: str="", t5: str="", scheduler: str="Euler a", ema: bool=True,
                                  base_repo: str="", mtype: str="", lora1: str="", lora1s=1.0, lora2: str="", lora2s=1.0, lora3: str="", lora3s=1.0,
                                  lora4: str="", lora4s=1.0, lora5: str="", lora5s=1.0, progress=gr.Progress(track_tqdm=True)):
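    """Convert dl_url to the diffusers format and upload it to hf_user/hf_repo on Hugging Face, returning updated URL choices and a markdown summary for the UI."""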
    try:
        is_local = False
        if not civitai_key and os.environ.get("CIVITAI_API_KEY"): civitai_key = os.environ.get("CIVITAI_API_KEY")
        if not hf_token and os.environ.get("HF_TOKEN"): hf_token = os.environ.get("HF_TOKEN")
        if not hf_user and os.environ.get("HF_USER"): hf_user = os.environ.get("HF_USER")
        if not hf_user: raise gr.Error(f"Invalid user name: {hf_user}")
        if not hf_repo and os.environ.get("HF_REPO"): hf_repo = os.environ.get("HF_REPO")
        if not is_overwrite and os.environ.get("HF_OW"): is_overwrite = os.environ.get("HF_OW")
        if not dl_url and os.environ.get("HF_URL"): dl_url = os.environ.get("HF_URL")
        set_token(hf_token)
        lora_dict = {lora1: lora1s, lora2: lora2s, lora3: lora3s, lora4: lora4s, lora5: lora5s}
        new_path = convert_url_to_diffusers(dl_url, civitai_key, is_upload_sf, dtype, vae, clip, t5, scheduler, ema, base_repo, mtype, lora_dict, is_local)
        if not new_path: return ""
        new_repo_id = f"{hf_user}/{Path(new_path).stem}"
        if hf_repo != "": new_repo_id = f"{hf_user}/{hf_repo}"
        if not is_repo_name(new_repo_id): raise gr.Error(f"Invalid repo name: {new_repo_id}")
        if not is_overwrite and is_repo_exists(new_repo_id): raise gr.Error(f"Repo already exists: {new_repo_id}")
        repo_url = upload_repo(new_repo_id, new_path, is_private)
        safe_clean(new_path)
        if not urls: urls = []
        urls.append(repo_url)
        md = "### Your new repo:\n"
        for u in urls:
            md += f"[{str(u).split('/')[-2]}/{str(u).split('/')[-1]}]({str(u)})<br>"
        return gr.update(value=urls, choices=urls), gr.update(value=md)
    except Exception as e:
        print(f"Error occurred. {e}")
        raise gr.Error(f"Error occurred. {e}")


if __name__ == "__main__":
    parser = argparse.ArgumentParser()

    parser.add_argument("--url", type=str, required=True, help="URL of the model to convert.")
    parser.add_argument("--dtype", default="fp16", type=str, choices=get_dtypes(), help='Output data type. (Default: "fp16")')
    parser.add_argument("--scheduler", default="Euler a", type=str, choices=list(SCHEDULER_CONFIG_MAP.keys()), required=False, help="Scheduler name to use.")
    parser.add_argument("--vae", default="", type=str, required=False, help="URL or Repo ID of the VAE to use.")
    parser.add_argument("--clip", default="", type=str, required=False, help="URL or Repo ID of the CLIP to use.")
    parser.add_argument("--t5", default="", type=str, required=False, help="URL or Repo ID of the T5 to use.")
    parser.add_argument("--base", default="", type=str, required=False, help="Repo ID of the base repo.")
    parser.add_argument("--nonema", action="store_true", default=False, help="Don't extract EMA (for SD 1.5).")
    parser.add_argument("--civitai_key", default="", type=str, required=False, help="Civitai API Key (needed to download files from Civitai).")
    parser.add_argument("--lora1", default="", type=str, required=False, help="URL of the LoRA to use.")
    parser.add_argument("--lora1s", default=1.0, type=float, required=False, help="LoRA weight scale of --lora1.")
    parser.add_argument("--lora2", default="", type=str, required=False, help="URL of the LoRA to use.")
    parser.add_argument("--lora2s", default=1.0, type=float, required=False, help="LoRA weight scale of --lora2.")
    parser.add_argument("--lora3", default="", type=str, required=False, help="URL of the LoRA to use.")
    parser.add_argument("--lora3s", default=1.0, type=float, required=False, help="LoRA weight scale of --lora3.")
    parser.add_argument("--lora4", default="", type=str, required=False, help="URL of the LoRA to use.")
    parser.add_argument("--lora4s", default=1.0, type=float, required=False, help="LoRA weight scale of --lora4.")
    parser.add_argument("--lora5", default="", type=str, required=False, help="URL of the LoRA to use.")
    parser.add_argument("--lora5s", default=1.0, type=float, required=False, help="LoRA weight scale of --lora5.")
    parser.add_argument("--loras", default="", type=str, required=False, help="Folder of LoRAs to use.")

    args = parser.parse_args()
    assert args.url is not None, "Must provide a URL!"

    is_local = True
    lora_dict = {args.lora1: args.lora1s, args.lora2: args.lora2s, args.lora3: args.lora3s, args.lora4: args.lora4s, args.lora5: args.lora5s}
    if args.loras and Path(args.loras).exists():
        for p in Path(args.loras).glob('**/*.safetensors'):
            lora_dict[str(p)] = 1.0
    ema = not args.nonema
    mtype = "SDXL"

    # Pass the optional settings by keyword so they line up with convert_url_to_diffusers()'s signature (is_upload_sf keeps its default).
    convert_url_to_diffusers(args.url, args.civitai_key, dtype=args.dtype, vae=args.vae, clip=args.clip, t5=args.t5, scheduler=args.scheduler,
                             ema=ema, base_repo=args.base, mtype=mtype, lora_dict=lora_dict, is_local=is_local)
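
# Example CLI invocation (the script name and model URL below are placeholders for illustration, not part of this repo):
#   python convert_url_to_diffusers.py --url "https://civitai.com/api/download/models/00000" --dtype fp16 --scheduler "Euler a" --civitai_key YOUR_CIVITAI_API_KEY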