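"""Gradio app for a Hugging Face Space: downloads a dataset repo from the Hub
and fine-tunes SDXL with DreamBooth LoRA by launching the diffusers training
script through `accelerate`."""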
import gradio as gr
import os
import subprocess
from huggingface_hub import snapshot_download

hf_token = os.environ.get("HF_TOKEN")
print("HF_TOKEN found" if hf_token else "HF_TOKEN is not set")  # avoid printing the secret itself

def set_accelerate_default_config():
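    """Write accelerate's default config so `accelerate launch` can run non-interactively."""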
    try:
        subprocess.run(["accelerate", "config", "default"], check=True)
        print("Accelerate default config set successfully!")
    except subprocess.CalledProcessError as e:
        print(f"An error occurred: {e}")

def train_dreambooth_lora_sdxl(instance_data_dir):
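    """Launch the diffusers DreamBooth LoRA SDXL training script with accelerate."""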
    
    script_filename = "train_dreambooth_lora_sdxl.py"  # Assuming it's in the same folder

    command = [
        "accelerate",
        "launch",
        script_filename,  # Use the local script
        "--pretrained_model_name_or_path=stabilityai/stable-diffusion-xl-base-1.0",
        "--pretrained_vae_model_name_or_path=madebyollin/sdxl-vae-fp16-fix",
        f"--instance_data_dir={instance_data_dir}",
        "--output_dir=lora-trained-xl-colab_2",
        "--mixed_precision=fp16",
        "--instance_prompt=egnestl",
        "--resolution=1024",
        "--train_batch_size=2",
        "--gradient_accumulation_steps=2",
        "--gradient_checkpointing",
        "--learning_rate=1e-4",
        "--lr_scheduler=constant",
        "--lr_warmup_steps=0",
        "--enable_xformers_memory_efficient_attention",
        "--mixed_precision=fp16",
        "--use_8bit_adam",
        "--max_train_steps=25",
        "--checkpointing_steps=717",
        "--seed=0",
        "--push_to_hub",
        f"--hub_token={hf_token}"
    ]

    try:
        subprocess.run(command, check=True)
        print("Training is finished!")
    except subprocess.CalledProcessError as e:
        print(f"An error occurred: {e}")

def main(dataset_url):
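    """Download the dataset repo locally, set up accelerate, and start training."""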

    dataset_repo = dataset_url

    # Automatically set local_dir based on the last part of dataset_repo
    repo_parts = dataset_repo.split("/")
    local_dir = f"./{repo_parts[-1]}"  # Use the last part of the split

    # Check if the directory exists and create it if necessary
    if not os.path.exists(local_dir):
        os.makedirs(local_dir)

    gr.Info("Downloading dataset ...")
    
    snapshot_download(
        dataset_repo,
        local_dir=local_dir,
        repo_type="dataset",
        ignore_patterns=".gitattributes",
        token=hf_token
    )

    set_accelerate_default_config()

    gr.Info("Training begins ...")

    instance_data_dir = repo_parts[-1]
    train_dreambooth_lora_sdxl(instance_data_dir)

    return "Done"

with gr.Blocks() as demo:
    with gr.Column():
        dataset_id = gr.Textbox(label="Dataset ID")
        train_button = gr.Button("Train !")
        status = gr.Textbox(label="Training status")

    train_button.click(
        fn=main,
        inputs=[dataset_id],
        outputs=[status]
    )

demo.queue().launch()