Spaces:
Runtime error
Runtime error
import gradio as gr | |
import os | |
import subprocess | |
from huggingface_hub import snapshot_download | |
hf_token = os.environ.get("HF_TOKEN") | |
print(hf_token) | |
def set_accelerate_default_config(): | |
try: | |
subprocess.run(["accelerate", "config", "default"], check=True) | |
print("Accelerate default config set successfully!") | |
except subprocess.CalledProcessError as e: | |
print(f"An error occurred: {e}") | |
def train_dreambooth_lora_sdxl(instance_data_dir): | |
script_filename = "train_dreambooth_lora_sdxl.py" # Assuming it's in the same folder | |
command = [ | |
"accelerate", | |
"launch", | |
script_filename, # Use the local script | |
"--pretrained_model_name_or_path=stabilityai/stable-diffusion-xl-base-1.0", | |
"--pretrained_vae_model_name_or_path=madebyollin/sdxl-vae-fp16-fix", | |
f"--instance_data_dir={instance_data_dir}", | |
"--output_dir=lora-trained-xl-colab_2", | |
"--mixed_precision=fp16", | |
"--instance_prompt=egnestl", | |
"--resolution=1024", | |
"--train_batch_size=2", | |
"--gradient_accumulation_steps=2", | |
"--gradient_checkpointing", | |
"--learning_rate=1e-4", | |
"--lr_scheduler=constant", | |
"--lr_warmup_steps=0", | |
"--enable_xformers_memory_efficient_attention", | |
"--mixed_precision=fp16", | |
"--use_8bit_adam", | |
"--max_train_steps=25", | |
"--checkpointing_steps=717", | |
"--seed=0", | |
"--push_to_hub", | |
f"--hub_token={hf_token}" | |
] | |
try: | |
subprocess.run(command, check=True) | |
print("Training is finished!") | |
except subprocess.CalledProcessError as e: | |
print(f"An error occurred: {e}") | |
def main(dataset_url): | |
dataset_repo = dataset_url | |
# Automatically set local_dir based on the last part of dataset_repo | |
repo_parts = dataset_repo.split("/") | |
local_dir = f"./{repo_parts[-1]}" # Use the last part of the split | |
# Check if the directory exists and create it if necessary | |
if not os.path.exists(local_dir): | |
os.makedirs(local_dir) | |
gr.Info("Downloading dataset ...") | |
snapshot_download( | |
dataset_repo, | |
local_dir=local_dir, | |
repo_type="dataset", | |
ignore_patterns=".gitattributes", | |
token=hf_token | |
) | |
set_accelerate_default_config() | |
gr.Info("Training begins ...") | |
instance_data_dir = repo_parts[-1] | |
train_dreambooth_lora_sdxl(instance_data_dir) | |
return "Done" | |
with gr.Blocks() as demo: | |
with gr.Column(): | |
dataset_id = gr.Textbox(label="Dataset ID") | |
train_button = gr.Button("Train !") | |
status = gr.Textbox(labe="Training status") | |
train_button.click( | |
fn = main, | |
inputs = [dataset_id], | |
outputs = [status] | |
) | |
demo.queue().launch() |