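# app.py -- last commit 383a495 ("Update app.py") by fffiloni; 3.74 kB.
# (Header recovered from the Hugging Face file-viewer page; the "raw",
# "history", "blame" and "No virus" entries were page navigation, not code.)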
import gradio as gr
import os
import subprocess
from huggingface_hub import snapshot_download
# Hugging Face access token read from the HF_TOKEN environment variable
# (None when the variable is unset). It is used below to download the dataset
# and is forwarded to the training script as its hub token.
hf_token = os.environ.get("HF_TOKEN")
def set_accelerate_default_config():
    """Run ``accelerate config default`` and report the outcome on stdout.

    A non-zero exit status from the command is printed rather than raised.
    """
    config_command = ["accelerate", "config", "default"]
    try:
        subprocess.run(config_command, check=True)
    except subprocess.CalledProcessError as err:
        print(f"An error occurred: {err}")
    else:
        print("Accelerate default config set successfully!")
def train_dreambooth_lora_sdxl(instance_data_dir, lora_trained_xl_folder, instance_prompt, max_train_steps, checkpoint_steps):
    """Launch the local DreamBooth LoRA SDXL training script via ``accelerate launch``.

    Args:
        instance_data_dir: Directory holding the training images.
        lora_trained_xl_folder: Output directory passed as ``--output_dir``.
        instance_prompt: Prompt describing the concept being trained.
        max_train_steps: Total number of training steps; coerced to ``int``.
        checkpoint_steps: Checkpoint interval in steps; coerced to ``int``.

    Returns:
        None. Failures (invalid step counts or a non-zero exit status of the
        training process) are printed to stdout, not raised.
    """
    # UI number widgets may deliver floats (e.g. 500.0); the step counts are
    # formatted as plain integers so the command line never contains "500.0".
    try:
        max_steps = int(max_train_steps)
        ckpt_steps = int(checkpoint_steps)
    except (TypeError, ValueError) as e:
        print(f"An error occurred: {e}")
        return

    script_filename = "train_dreambooth_lora_sdxl.py"  # Assuming it's in the same folder

    command = [
        "accelerate",
        "launch",
        script_filename,  # Use the local script
        "--pretrained_model_name_or_path=stabilityai/stable-diffusion-xl-base-1.0",
        "--pretrained_vae_model_name_or_path=madebyollin/sdxl-vae-fp16-fix",
        f"--instance_data_dir={instance_data_dir}",
        f"--output_dir={lora_trained_xl_folder}",
        "--mixed_precision=fp16",  # previously listed twice; once is enough
        f"--instance_prompt={instance_prompt}",
        "--resolution=1024",
        "--train_batch_size=2",
        "--gradient_accumulation_steps=2",
        "--gradient_checkpointing",
        "--learning_rate=1e-4",
        "--lr_scheduler=constant",
        "--lr_warmup_steps=0",
        "--enable_xformers_memory_efficient_attention",
        "--use_8bit_adam",
        f"--max_train_steps={max_steps}",
        f"--checkpointing_steps={ckpt_steps}",
        "--seed=0",
        "--push_to_hub",
    ]

    # Only forward a token when one is configured; otherwise the literal
    # string "None" would be sent as the hub token.
    # NOTE(review): presumably the script falls back to cached credentials
    # when --hub_token is omitted -- confirm against the local script.
    if hf_token:
        command.append(f"--hub_token={hf_token}")

    try:
        subprocess.run(command, check=True)
        print("Training is finished!")
    except subprocess.CalledProcessError as e:
        print(f"An error occurred: {e}")
def main(dataset_id,
         lora_trained_xl_folder,
         instance_prompt,
         max_train_steps,
         checkpoint_steps):
    """Download a dataset from the Hub, then train a DreamBooth LoRA on it.

    Args:
        dataset_id: Hub dataset repository id, e.g. ``"diffusers/dog-example"``.
        lora_trained_xl_folder: Output folder name for the trained model.
        instance_prompt: Prompt describing the concept being trained.
        max_train_steps: Total number of training steps.
        checkpoint_steps: Checkpoint interval in steps.

    Returns:
        A status message naming the output folder.
    """
    # Tolerate stray whitespace or slashes so the last path component is
    # never empty for inputs such as "user/repo/".
    dataset_repo = dataset_id.strip().strip("/")

    # Automatically set local_dir based on the last part of dataset_repo
    repo_parts = dataset_repo.split("/")
    local_dir = f"./{repo_parts[-1]}"  # Use the last part of the split

    # Create the download directory if it does not exist yet
    os.makedirs(local_dir, exist_ok=True)

    gr.Info("Downloading dataset ...")

    snapshot_download(
        dataset_repo,
        local_dir=local_dir,
        repo_type="dataset",
        ignore_patterns=".gitattributes",
        token=hf_token,
    )

    set_accelerate_default_config()

    gr.Info("Training begins ...")

    # Relative path to the directory the dataset was downloaded into
    instance_data_dir = repo_parts[-1]
    train_dreambooth_lora_sdxl(instance_data_dir, lora_trained_xl_folder, instance_prompt, max_train_steps, checkpoint_steps)

    # The placeholder previously read {lora-trained-xl-folder}, which Python
    # evaluates as a subtraction of undefined names and raises NameError.
    return f"Done, your trained model has been stored in your models library: your_user_name/{lora_trained_xl_folder}"
# Gradio UI: collects the training parameters and runs `main` on click.
with gr.Blocks() as demo:
    with gr.Column():
        dataset_id = gr.Textbox(label="Dataset ID", placeholder="diffusers/dog-example")
        instance_prompt = gr.Textbox(label="Concept prompt", info="concept prompt - use a unique, made up word to avoid collisions")
        model_output_folder = gr.Textbox(label="Output model folder name", placeholder="lora-trained-xl-folder")
        with gr.Row():
            # precision=0 makes the widgets return ints, which is what the
            # step-count arguments of the training command need.
            max_train_steps = gr.Number(label="Max training steps", value=500, precision=0)
            checkpoint_steps = gr.Number(label="Checkpoint steps", value=100, precision=0)
        train_button = gr.Button("Train !")
        # The keyword was misspelled "labe", which raises TypeError at startup.
        status = gr.Textbox(label="Training status")

    # Input order must match the positional parameters of main().
    train_button.click(
        fn=main,
        inputs=[
            dataset_id,
            model_output_folder,
            instance_prompt,
            max_train_steps,
            checkpoint_steps,
        ],
        outputs=[status],
    )

demo.queue().launch()