|
import gradio as gr |
|
import os |
|
import subprocess |
|
from huggingface_hub import snapshot_download |
|
|
|
# Hugging Face access token taken from the environment (None when unset).
# Used for the dataset download and forwarded to the training script.
hf_token = os.getenv("HF_TOKEN")
|
|
|
|
|
def set_accelerate_default_config():
    """Write a default ``accelerate`` configuration by invoking its CLI.

    Runs ``accelerate config default``. A non-zero exit status is reported on
    stdout rather than raised, so callers continue either way. Other errors
    from ``subprocess.run`` (for example ``FileNotFoundError`` when the
    ``accelerate`` executable is not on PATH) are not caught and propagate.
    """
    cli_args = ["accelerate", "config", "default"]
    try:
        subprocess.run(cli_args, check=True)
    except subprocess.CalledProcessError as err:
        print(f"An error occurred: {err}")
    else:
        print("Accelerate default config set successfully!")
|
|
|
def train_dreambooth_lora_sdxl(instance_data_dir, lora_trained_xl_folder, instance_prompt, max_train_steps, checkpoint_steps):
    """Run the DreamBooth LoRA SDXL training script under ``accelerate launch``.

    The script ``train_dreambooth_lora_sdxl.py`` is expected in the current
    working directory. It is run with a fixed set of hyper-parameters
    (fp16, 1024px, batch size 2, grad accumulation 2, lr 1e-4, seed 0) and
    with ``--push_to_hub``.

    Args:
        instance_data_dir: Directory holding the training images.
        lora_trained_xl_folder: Value forwarded as ``--output_dir``.
        instance_prompt: Concept prompt forwarded as ``--instance_prompt``.
        max_train_steps: Total training steps. It is coerced to ``int``
            because Gradio ``Number`` inputs may arrive as floats (``500.0``).
        checkpoint_steps: Checkpoint interval in steps, coerced to ``int``.

    A failed training run (non-zero exit status) is printed, not raised.
    Reads the module-level ``hf_token``.
    """
    script_filename = "train_dreambooth_lora_sdxl.py"

    command = [
        "accelerate",
        "launch",
        script_filename,
        "--pretrained_model_name_or_path=stabilityai/stable-diffusion-xl-base-1.0",
        "--pretrained_vae_model_name_or_path=madebyollin/sdxl-vae-fp16-fix",
        f"--instance_data_dir={instance_data_dir}",
        f"--output_dir={lora_trained_xl_folder}",
        "--mixed_precision=fp16",
        f"--instance_prompt={instance_prompt}",
        "--resolution=1024",
        "--train_batch_size=2",
        "--gradient_accumulation_steps=2",
        "--gradient_checkpointing",
        "--learning_rate=1e-4",
        "--lr_scheduler=constant",
        "--lr_warmup_steps=0",
        "--enable_xformers_memory_efficient_attention",
        "--use_8bit_adam",
        # A float such as "500.0" would be rejected by the script's integer
        # options (presumably argparse type=int), so format as plain ints.
        f"--max_train_steps={int(max_train_steps)}",
        f"--checkpointing_steps={int(checkpoint_steps)}",
        "--seed=0",
        "--push_to_hub",
    ]
    # Forward the token only when it is set. Without this check the literal
    # string "None" would be passed as the token. When omitted, the script
    # presumably falls back to cached huggingface_hub credentials (TODO confirm).
    if hf_token:
        command.append(f"--hub_token={hf_token}")

    try:
        subprocess.run(command, check=True)
    except subprocess.CalledProcessError as e:
        print(f"An error occurred: {e}")
    else:
        print("Training is finished!")
|
|
|
def main(dataset_id,
         lora_trained_xl_folder,
         instance_prompt,
         max_train_steps,
         checkpoint_steps):
    """Gradio click handler: download a dataset, then train a DreamBooth LoRA on it.

    Args:
        dataset_id: Hugging Face dataset repo id, e.g. ``"diffusers/dog-example"``.
            The dataset is downloaded to ``./<last path component>``.
        lora_trained_xl_folder: Output folder name for the trained LoRA. It is
            passed to the training script and echoed in the status message.
        instance_prompt: Concept prompt passed to the training script.
        max_train_steps: Total number of training steps.
        checkpoint_steps: Checkpoint interval in steps.

    Returns:
        A status string for the "Training status" textbox.

    Raises:
        gr.Error: If ``dataset_id`` or ``lora_trained_xl_folder`` is blank.
    """
    # Tolerate surrounding whitespace and stray slashes in the repo id.
    dataset_repo = (dataset_id or "").strip().strip("/")
    output_folder = (lora_trained_xl_folder or "").strip()
    if not dataset_repo:
        raise gr.Error("Please provide a dataset ID, e.g. diffusers/dog-example")
    if not output_folder:
        raise gr.Error("Please provide an output model folder name")

    # "diffusers/dog-example" -> "./dog-example"
    local_dir = f"./{dataset_repo.split('/')[-1]}"
    os.makedirs(local_dir, exist_ok=True)

    gr.Info("Downloading dataset ...")
    snapshot_download(
        dataset_repo,
        local_dir=local_dir,
        repo_type="dataset",
        ignore_patterns=".gitattributes",
        token=hf_token,
    )

    set_accelerate_default_config()

    gr.Info("Training begins ...")
    # NOTE(review): train_dreambooth_lora_sdxl prints (rather than raises) when
    # the training process fails, so this handler cannot tell success from
    # failure and always reports "Done".
    train_dreambooth_lora_sdxl(local_dir, output_folder, instance_prompt, max_train_steps, checkpoint_steps)

    return f"Done, your trained model has been stored in your models library: your_user_name/{output_folder}"
|
|
|
# Gradio UI: collects training inputs and runs `main` when the button is clicked.
with gr.Blocks() as demo:
    with gr.Column():
        dataset_id = gr.Textbox(label="Dataset ID", placeholder="diffusers/dog-example")
        instance_prompt = gr.Textbox(label="Concept prompt", info="concept prompt - use a unique, made up word to avoid collisions")
        model_output_folder = gr.Textbox(label="Output model folder name", placeholder="lora-trained-xl-folder")
        with gr.Row():
            # precision=0 makes Gradio return ints instead of floats such as 500.0,
            # which the training script's integer step options would reject.
            max_train_steps = gr.Number(label="Max training steps", value=500, precision=0)
            checkpoint_steps = gr.Number(label="Checkpoint steps", value=100, precision=0)
        train_button = gr.Button("Train !")
        status = gr.Textbox(label="Training status")

    # Event listeners must be attached inside the Blocks context.
    # Input order matches main's positional parameters.
    train_button.click(
        fn=main,
        inputs=[
            dataset_id,
            model_output_folder,
            instance_prompt,
            max_train_steps,
            checkpoint_steps,
        ],
        outputs=[status],
    )

demo.queue().launch()