import io
import os
import shutil
import zipfile
import gradio as gr
import requests
from huggingface_hub import create_repo, upload_folder, whoami
from convert import convert_full_checkpoint
# Scratch workspace used by every conversion job; the handlers below wipe it
# (shutil.rmtree) at the start of each run.
MODELS_DIR = "models/"
CKPT_FILE = f"{MODELS_DIR}model.ckpt"  # downloaded source checkpoint
HF_MODEL_DIR = f"{MODELS_DIR}diffusers_model"  # converted Diffusers output folder
ZIP_FILE = f"{MODELS_DIR}model.zip"  # archive returned to the user for download
def download_ckpt(url, out_path, chunk_size=8192, timeout=60):
    """Stream the file at ``url`` to ``out_path`` in binary chunks.

    The HTTP status is checked before ``out_path`` is opened, so an error
    response does not leave an empty file behind.

    Args:
        url: HTTP(S) URL to download.
        out_path: destination file path (overwritten).
        chunk_size: number of bytes requested per streamed chunk.
        timeout: seconds to wait for the connection and for each read. This
            prevents a stalled server from hanging the worker forever.

    Raises:
        requests.HTTPError: if the server answers with an error status.
        requests.Timeout: if the server stalls for longer than ``timeout``.
    """
    with requests.get(url, stream=True, timeout=timeout) as response:
        response.raise_for_status()
        with open(out_path, "wb") as out_file:
            for chunk in response.iter_content(chunk_size=chunk_size):
                if chunk:  # skip empty keep-alive chunks
                    out_file.write(chunk)
def zip_model(model_path, zip_path):
    """Pack every file under ``model_path`` into an uncompressed zip archive.

    Member names are made relative to the *parent* of ``model_path``, so the
    archive keeps the model directory's own name as its top-level folder.
    Files are stored without compression (``ZIP_STORED``).
    """
    parent_dir = os.path.join(model_path, "..")
    with zipfile.ZipFile(zip_path, "w", zipfile.ZIP_STORED) as archive:
        for dirpath, _subdirs, filenames in os.walk(model_path):
            for name in filenames:
                full_path = os.path.join(dirpath, name)
                archive.write(full_path, os.path.relpath(full_path, parent_dir))
def download_checkpoint_and_config(ckpt_url, config_url):
    """Download the checkpoint and load the YAML config that describes it.

    Args:
        ckpt_url: http(s) URL of the ``.ckpt`` file.
        config_url: http(s) URL of a ``.yaml`` config, or an empty string to
            fall back to the local ``original_config.yaml``.

    Returns:
        ``(ckpt_path, config_file)``. ``config_file`` is always an in-memory
        binary stream (``io.BytesIO``) holding the YAML bytes, whichever
        source it came from.

    Raises:
        ValueError: if either URL is not an http(s) URL.
        requests.HTTPError: if fetching the config or the checkpoint fails.
    """
    schemes = ("http://", "https://")
    # Treat a missing value like an empty textbox.
    ckpt_url = (ckpt_url or "").strip()
    config_url = (config_url or "").strip()
    if not ckpt_url.startswith(schemes):
        raise ValueError("Invalid checkpoint URL")
    if config_url.startswith(schemes):
        response = requests.get(config_url, timeout=60)
        response.raise_for_status()
        config_file = io.BytesIO(response.content)
    elif config_url:
        raise ValueError("Invalid config URL")
    else:
        # Read the local config into memory so no file handle is left open
        # and both branches hand the converter the same stream type.
        with open("original_config.yaml", "rb") as local_config:
            config_file = io.BytesIO(local_config.read())
    download_ckpt(ckpt_url, CKPT_FILE)
    return CKPT_FILE, config_file
def convert_and_download(ckpt_url, config_url, scheduler_type, extract_ema):
    """Convert a remote ``.ckpt`` checkpoint and package it for download.

    The shared ``MODELS_DIR`` workspace is wiped first. Then the checkpoint
    and config are fetched, converted into ``HF_MODEL_DIR`` and zipped.

    Args:
        ckpt_url: URL of the checkpoint file.
        config_url: URL of the YAML config, or an empty string for the default.
        scheduler_type: scheduler name passed through to the converter.
        extract_ema: ``"EMA"`` to extract EMA weights; any other value
            disables EMA extraction.

    Returns:
        Path of the zip archive (``ZIP_FILE``).
    """
    # Every job starts from a clean scratch directory.
    shutil.rmtree(MODELS_DIR, ignore_errors=True)
    os.makedirs(HF_MODEL_DIR)

    checkpoint_path, original_config = download_checkpoint_and_config(
        ckpt_url, config_url
    )
    use_ema = extract_ema == "EMA"
    convert_full_checkpoint(
        checkpoint_path,
        original_config,
        scheduler_type=scheduler_type,
        extract_ema=use_ema,
        output_path=HF_MODEL_DIR,
    )

    zip_model(HF_MODEL_DIR, ZIP_FILE)
    return ZIP_FILE
def convert_and_upload(
    ckpt_url, config_url, scheduler_type, extract_ema, token, model_name
):
    """Convert a remote checkpoint and push the Diffusers model to the Hub.

    The repository ``<username>/<model_name>`` is created if needed. The
    username comes from ``whoami(token)``.

    Args:
        ckpt_url: URL of the checkpoint file.
        config_url: URL of the YAML config, or an empty string for the default.
        scheduler_type: scheduler name passed through to the converter.
        extract_ema: ``"EMA"`` to extract EMA weights; anything else disables it.
        token: Hugging Face access token with write permission.
        model_name: repository name (without the username prefix).

    Returns:
        A Markdown status string. Starts with ``"#### Error:"`` on failure and
        with ``"#### Success!"`` plus the repository link on success. Errors
        are reported, not raised, because the string is shown in the UI.
    """
    token = (token or "").strip()
    model_name = (model_name or "").strip()
    # Reject obviously incomplete requests before wiping the shared
    # workspace or starting any network traffic.
    if not token:
        return "#### Error: please provide a Hugging Face token with write access"
    if not model_name:
        return "#### Error: please provide a model name"

    shutil.rmtree(MODELS_DIR, ignore_errors=True)
    os.makedirs(HF_MODEL_DIR)
    try:
        # Resolve the user first: an invalid token fails fast instead of
        # after a potentially multi-gigabyte checkpoint download.
        username = whoami(token)["name"]
        repo_name = f"{username}/{model_name}"
        ckpt_path, config_file = download_checkpoint_and_config(ckpt_url, config_url)
        repo_url = create_repo(repo_name, token=token, exist_ok=True)
        convert_full_checkpoint(
            ckpt_path,
            config_file,
            scheduler_type=scheduler_type,
            extract_ema=(extract_ema == "EMA"),
            output_path=HF_MODEL_DIR,
        )
        upload_folder(
            repo_id=repo_name,
            folder_path=HF_MODEL_DIR,
            token=token,
            commit_message="Upload diffusers weights",
        )
    except Exception as e:  # UI boundary: surface any failure as a message
        return f"#### Error: {e}"
    return f"#### Success! Model uploaded to [{repo_url}]({repo_url})"
# HTML header rendered at the top of the page; the image markup is currently
# empty (just a newline).
TITLE_IMAGE = """
"""
# Backward-compatible alias for the original, misspelled name.
TTILE_IMAGE = TITLE_IMAGE
# HTML title line rendered below the header image.
TITLE = """
Convert Stable Diffusion `.ckpt` files to Hugging Face Diffusers 🔥
"""
# Gradio UI. Step 1 collects the checkpoint/config URLs and conversion options.
# Step 2 picks an output mode. Step 3 (one panel per mode) runs the job.
with gr.Blocks() as interface:
    gr.HTML(TTILE_IMAGE)
    gr.HTML(TITLE)
    gr.Markdown("We will perform all of the checkpoint surgery for you, and create a clean diffusers model!")
    gr.Markdown("This converter will also remove any pickled code from third-party checkpoints.")

    with gr.Row():
        with gr.Column(scale=50):
            gr.Markdown("### 1. Paste a URL to your .ckpt file")
            ckpt_url = gr.Textbox(
                max_lines=1,
                label="URL to .ckpt",
                placeholder="https://huggingface.co/runwayml/stable-diffusion-v1-5/resolve/main/v1-5-pruned.ckpt",
            )
        with gr.Column(scale=50):
            gr.Markdown("### (Optional) paste a URL to your .yaml file")
            config_url = gr.Textbox(
                max_lines=1,
                label="URL to .yaml",
                placeholder="https://huggingface.co/runwayml/stable-diffusion-v1-5/resolve/main/v1-inference.yaml",
            )
            # Bold markers must be balanced (`**...**`) or Markdown renders
            # the asterisks literally.
            gr.Markdown(
                "**If you don't provide a config file, we'll try to use"
                " [v1-inference.yaml](https://huggingface.co/runwayml/stable-diffusion-v1-5/blob/main/v1-inference.yaml).**"
            )
            with gr.Accordion("Advanced Settings"):
                scheduler_type = gr.Dropdown(
                    label="Choose a scheduler type (if not sure, keep the PNDM default)",
                    choices=["PNDM", "K-LMS", "Euler", "EulerAncestral", "DDIM"],
                    value="PNDM",
                )
                extract_ema = gr.Radio(
                    label=(
                        "EMA weights usually yield higher quality images for inference."
                        " Non-EMA weights are usually better to continue fine-tuning."
                    ),
                    choices=["EMA", "Non-EMA"],
                    value="EMA",
                    interactive=True,
                )

    gr.Markdown("### 2. Choose what to do with the converted model")
    # type="index" makes change handlers receive the selected position (0, 1, ...).
    model_choice = gr.Radio(
        show_label=False,
        choices=[
            "Download the model as an archive",
            "Host the model on the Hugging Face Hub",
            # "Submit a PR with the model for an existing Hub repository",
        ],
        type="index",
        value="Download the model as an archive",
        interactive=True,
    )

    download_panel = gr.Column(visible=True)
    upload_panel = gr.Column(visible=False)
    # pr_panel = gr.Column(visible=False)

    # One handler toggles both panels, so exactly the panel matching the
    # selected index is visible.
    # NOTE: if the PR option is re-enabled, add pr_panel to these outputs.
    model_choice.change(
        fn=lambda i: (
            gr.update(visible=(i == 0)),
            gr.update(visible=(i == 1)),
        ),
        inputs=model_choice,
        outputs=[download_panel, upload_panel],
    )

    with download_panel:
        gr.Markdown("### 3. Convert and download")
        down_btn = gr.Button("Convert")
        output_file = gr.File(
            label="Download the converted model",
            type="binary",
            interactive=False,
            visible=True,
        )
        down_btn.click(
            fn=convert_and_download,
            inputs=[ckpt_url, config_url, scheduler_type, extract_ema],
            outputs=output_file,
        )

    with upload_panel:
        gr.Markdown("### 3. Convert and host on the Hub")
        gr.Markdown(
            "This will create a new repository if it doesn't exist yet, and upload the model to the Hugging Face Hub.\n\n"
            "Paste a WRITE token from [https://huggingface.co/settings/tokens](https://huggingface.co/settings/tokens)"
            " and make up a model name."
        )
        # The token grants write access, so mask it on screen.
        up_token = gr.Textbox(
            max_lines=1,
            label="Hugging Face token",
            type="password",
        )
        up_model_name = gr.Textbox(
            max_lines=1,
            label="Hub model name (e.g. `artistic-diffusion-v1`)",
            placeholder="my-awesome-model",
        )
        upload_btn = gr.Button("Convert and upload")
        with gr.Box():
            output_text = gr.Markdown()
        upload_btn.click(
            fn=convert_and_upload,
            inputs=[
                ckpt_url,
                config_url,
                scheduler_type,
                extract_ema,
                up_token,
                up_model_name,
            ],
            outputs=output_text,
        )

    # TODO: the "submit a PR" mode below is disabled; re-enable together with
    # its choice in model_choice and the visibility handler above.
    # with pr_panel:
    #     gr.Markdown("### 3. Convert and submit as a PR")
    #     gr.Markdown(
    #         "This will open a Pull Request on the original model repository, if it already exists on the Hub.\n\n"
    #         "Paste a write-access token from [https://huggingface.co/settings/tokens](https://huggingface.co/settings/tokens)"
    #         " and paste an existing model id from the Hub in the `username/model-name` form."
    #     )
    #     pr_token = gr.Textbox(
    #         max_lines=1,
    #         label="Hugging Face token",
    #     )
    #     pr_model_name = gr.Textbox(
    #         max_lines=1,
    #         label="Hub model name (e.g. `diffuser/artistic-diffusion-v1`)",
    #         placeholder="diffuser/my-awesome-model",
    #     )
    #
    #     btn = gr.Button("Convert and open a PR")
    #     output = gr.Markdown(label="Output")
# Both handlers wipe and reuse the single MODELS_DIR scratch workspace, so
# requests are queued and processed one at a time.
interface.queue(concurrency_count=1).launch()