File size: 8,149 Bytes
05fd390
4efa9a6
05fd390
9db9711
05fd390
 
62b04c4
463536c
9db9711
 
 
 
 
 
 
 
 
 
 
 
4efa9a6
05fd390
 
 
 
 
 
 
9db9711
05fd390
 
 
 
 
 
 
 
 
 
118c8fd
05fd390
463536c
05fd390
 
 
 
 
 
 
 
 
 
463536c
 
05fd390
62b04c4
 
05fd390
 
 
 
 
9db9711
 
05fd390
 
9db9711
 
 
 
 
 
 
 
 
 
 
05fd390
463536c
383a495
463536c
 
9db9711
 
 
 
 
 
05fd390
9db9711
 
 
 
 
463536c
05fd390
 
 
 
 
8f30316
 
 
 
05fd390
 
 
 
 
 
 
4efa9a6
05fd390
 
 
 
 
f559d19
 
9db9711
05fd390
463536c
05fd390
 
 
9db9711
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
463536c
fc27a96
 
 
 
 
 
 
9db9711
05fd390
9db9711
 
 
05fd390
86cbf7f
 
463536c
 
 
118c8fd
463536c
9db9711
 
463536c
86cbf7f
 
05fd390
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
import os
import subprocess
from subprocess import getoutput

import gradio as gr
import torch
from huggingface_hub import snapshot_download

# Hugging Face token used for downloading datasets and pushing the trained model.
hf_token = os.environ.get("HF_TOKEN")

# SPACE_ID is set by the Spaces runtime; default to "" so the app can also be
# started locally without raising KeyError.
space_id = os.environ.get("SPACE_ID", "")
# The original (shared) Space refuses to train; only duplicates may.
is_shared_ui = "fffiloni/train-dreambooth-lora-sdxl" in space_id

is_gpu_associated = torch.cuda.is_available()
# Always defined so later code can reference it even without a GPU.
which_gpu = "CPU"
if is_gpu_associated:
    gpu_info = getoutput('nvidia-smi')
    if "A10G" in gpu_info:
        which_gpu = "A10G"
    elif "T4" in gpu_info:
        which_gpu = "T4"

def set_accelerate_default_config():
    """Run ``accelerate config default`` so ``accelerate launch`` has a config.

    A non-zero exit status is reported on stdout rather than raised; the
    function always returns None.
    """
    accelerate_cmd = ["accelerate", "config", "default"]
    try:
        subprocess.run(accelerate_cmd, check=True)
    except subprocess.CalledProcessError as failure:
        print(f"An error occurred: {failure}")
        return
    print("Accelerate default config set successfully!")

def _redact_token(text):
    """Return *text* with the module-level ``hf_token`` masked, if one is set."""
    return text.replace(hf_token, "***") if hf_token else text


def _swap_hardware(token, hardware="cpu-basic"):
    """Best-effort request to move this Space onto *hardware*.

    Does nothing when ``SPACE_ID`` is unset (not running on Spaces). Errors are
    printed rather than raised so they never mask the training outcome.
    """
    space_id = os.environ.get("SPACE_ID")
    if not space_id:
        print(f"SPACE_ID is not set; skipping hardware swap to {hardware}.")
        return
    from huggingface_hub import HfApi

    try:
        HfApi().request_space_hardware(repo_id=space_id, hardware=hardware, token=token)
    except Exception as err:  # network/auth failures must not crash the app
        print(f"Could not swap hardware to {hardware}: {_redact_token(str(err))}")


def _write_to_community(title, description, token):
    """Best-effort: open a discussion on this Space's community tab.

    Does nothing when ``SPACE_ID`` is unset. Errors are printed, not raised.
    """
    space_id = os.environ.get("SPACE_ID")
    if not space_id:
        print("SPACE_ID is not set; skipping community report.")
        return
    from huggingface_hub import HfApi

    try:
        HfApi().create_discussion(
            repo_id=space_id,
            title=title,
            description=description,
            repo_type="space",
            token=token,
        )
    except Exception as err:  # reporting is optional; never crash on it
        print(f"Could not post to the community tab: {_redact_token(str(err))}")


def train_dreambooth_lora_sdxl(instance_data_dir, lora_trained_xl_folder, instance_prompt, max_train_steps, checkpoint_steps, remove_gpu):
    """Run the local SDXL DreamBooth-LoRA training script via ``accelerate launch``.

    Args:
        instance_data_dir: directory containing the training images.
        lora_trained_xl_folder: output directory passed as ``--output_dir``.
        instance_prompt: concept prompt passed as ``--instance_prompt``.
        max_train_steps: total training steps; coerced to ``int`` because a
            Gradio ``Number`` input may deliver a float.
        checkpoint_steps: checkpoint interval; coerced to ``int`` likewise.
        remove_gpu: if true, request ``cpu-basic`` hardware after a successful run.

    Returns:
        True if the training subprocess exited successfully, False otherwise.
        On failure, a token-redacted report is posted to the Space's community
        tab and ``cpu-basic`` hardware is requested regardless of *remove_gpu*.
    """
    script_filename = "train_dreambooth_lora_sdxl.py"  # expected next to this app

    command = [
        "accelerate",
        "launch",
        script_filename,
        "--pretrained_model_name_or_path=stabilityai/stable-diffusion-xl-base-1.0",
        "--pretrained_vae_model_name_or_path=madebyollin/sdxl-vae-fp16-fix",
        f"--instance_data_dir={instance_data_dir}",
        f"--output_dir={lora_trained_xl_folder}",
        "--mixed_precision=fp16",
        f"--instance_prompt={instance_prompt}",
        "--resolution=1024",
        "--train_batch_size=2",
        "--gradient_accumulation_steps=2",
        "--gradient_checkpointing",
        "--learning_rate=1e-4",
        "--lr_scheduler=constant",
        "--lr_warmup_steps=0",
        "--enable_xformers_memory_efficient_attention",
        "--use_8bit_adam",
        # Step counts must be whole numbers; a float would render as "500.0".
        f"--max_train_steps={int(max_train_steps)}",
        f"--checkpointing_steps={int(checkpoint_steps)}",
        "--seed=0",
        "--push_to_hub",
    ]
    if hf_token:
        # Only forward a real token; formatting None would send the text "None".
        command.append(f"--hub_token={hf_token}")

    try:
        subprocess.run(command, check=True)
    except (subprocess.CalledProcessError, OSError) as err:
        # str() of a CalledProcessError embeds the whole command line,
        # including --hub_token, so redact before logging or posting publicly.
        error_text = _redact_token(str(err))
        print(f"An error occurred: {error_text}")

        title = "There was an error during your training"
        description = (
            "Unfortunately there was an error while training your "
            f"{lora_trained_xl_folder} model.\n"
            "Please check it out below. Feel free to report this issue to "
            "[SD-XL Dreambooth LoRa Training]"
            "(https://huggingface.co/spaces/fffiloni/train-dreambooth-lora-sdxl):\n"
            f"```\n{error_text}\n```\n"
        )
        # Post the report before releasing the GPU: a hardware change is
        # presumably followed by a Space restart, which could kill this process.
        _write_to_community(title, description, hf_token)
        _swap_hardware(hf_token, "cpu-basic")
        return False

    print("Training is finished!")
    if remove_gpu:
        _swap_hardware(hf_token, "cpu-basic")
    return True

def main(dataset_id,
         lora_trained_xl_folder,
         instance_prompt,
         max_train_steps,
         checkpoint_steps,
         remove_gpu):
    """Gradio click handler: download a Hub dataset, then train an SDXL LoRA.

    Args:
        dataset_id: Hub dataset repo id (e.g. ``diffusers/dog-example``).
        lora_trained_xl_folder: output folder name forwarded to training.
        instance_prompt: concept prompt forwarded to training.
        max_train_steps: forwarded to training.
        checkpoint_steps: forwarded to training.
        remove_gpu: forwarded to training (release the GPU when done).

    Returns:
        A status message for the "Training status" textbox.

    Raises:
        gr.Error: on the shared UI, without a GPU, or when the dataset ID or
            output folder name is empty.
    """
    if is_shared_ui:
        raise gr.Error("This Space only works in duplicated instances")

    if not is_gpu_associated:
        raise gr.Error("Please associate a T4 or A10G GPU for this Space")

    # Validate before doing any work so the user gets an immediate error.
    dataset_repo = (dataset_id or "").strip()
    output_folder = (lora_trained_xl_folder or "").strip()
    dataset_name = dataset_repo.split("/")[-1]
    if not dataset_name:
        raise gr.Error("Please provide a Dataset ID, e.g. diffusers/dog-example")
    if not output_folder:
        raise gr.Error("Please provide an output model folder name")

    gr.Warning("## Training is ongoing ⌛... You can close this tab if you like or just wait. If you did not check the `Remove GPU After training`, you can come back here to try your model and upload it after training. Don't forget to remove the GPU attribution after you are done. ")

    # Download into ./<last path component of the repo id>.
    local_dir = f"./{dataset_name}"
    os.makedirs(local_dir, exist_ok=True)

    gr.Info("Downloading dataset ...")

    snapshot_download(
        dataset_repo,
        local_dir=local_dir,
        repo_type="dataset",
        ignore_patterns=".gitattributes",
        token=hf_token,
    )

    set_accelerate_default_config()

    gr.Info("Training begins ...")

    succeeded = train_dreambooth_lora_sdxl(
        local_dir,
        output_folder,
        instance_prompt,
        max_train_steps,
        checkpoint_steps,
        remove_gpu,
    )
    # Compare with `is False` so a None return (no status reported) still
    # counts as success.
    if succeeded is False:
        return "Training failed. See the Space logs for details."

    return f"Done, your trained model has been stored in your models library: your_user_name/{output_folder}"

# SPACE_ID is only set on Hugging Face Spaces; fall back to "" when run locally
# so building the UI does not raise KeyError.
_ui_space_id = os.environ.get("SPACE_ID", "")

with gr.Blocks() as demo:
    with gr.Column():
        # Header: explains what to do depending on where/how the app runs.
        if is_shared_ui:
            top_description = gr.HTML(f'''
                <div class="gr-prose" style="max-width: 80%">
                <h2>Attention - This Space doesn't work in this shared UI</h2>
                <p>For it to work, you can duplicate the Space and run it on your own profile using a (paid) private T4-small or A10G-small GPU for training. A T4 costs US$0.60/h, so it should cost < US$1 to train most models using default settings with it!&nbsp;&nbsp;<a class="duplicate-button" style="display:inline-block" target="_blank" href="https://huggingface.co/spaces/{_ui_space_id}?duplicate=true"><img src="https://img.shields.io/badge/-Duplicate%20Space-blue?labelColor=white&style=flat&logo=&logoWidth=14" alt="Duplicate Space"></a></p>
                <img class="instruction" src="file=duplicate.png"> 
                <img class="arrow" src="file=arrow.png" />
                </div>
            ''')
        elif is_gpu_associated:
            top_description = gr.HTML(f'''
                    <div class="gr-prose" style="max-width: 80%">
                    <h2>You have successfully associated a {which_gpu} GPU to the SD-XL Dreambooth LoRa Training Space 🎉</h2>
                    <p>You can now train your model! You will be billed by the minute from when you activated the GPU until when it is turned it off.</p> 
                    </div>
            ''')
        else:
            top_description = gr.HTML(f'''
                    <div class="gr-prose" style="max-width: 80%">
                    <h2>You have successfully duplicated the SD-XL Dreambooth LoRa Training Space 🎉</h2>
                    <p>There's only one step left before you can train your model: <a href="https://huggingface.co/spaces/{_ui_space_id}/settings" style="text-decoration: underline" target="_blank">attribute a <b>T4-small or A10G-small GPU</b> to it (via the Settings tab)</a> and run the training below. You will be billed by the minute from when you activate the GPU until when it is turned it off.</p> 
                    </div>
            ''')

        with gr.Row():
            dataset_id = gr.Textbox(label="Dataset ID", info="use one of your previously uploaded datasets on your HF profile", placeholder="diffusers/dog-example")
            instance_prompt = gr.Textbox(label="Concept prompt", info="concept prompt - use a unique, made up word to avoid collisions")

        with gr.Row():
            model_output_folder = gr.Textbox(label="Output model folder name", placeholder="lora-trained-xl-folder")
            # precision=0 makes Gradio deliver ints; the step counts are used
            # as whole numbers in the training command.
            max_train_steps = gr.Number(label="Max Training Steps", value=500, precision=0)
            checkpoint_steps = gr.Number(label="Checkpoints Steps", value=100, precision=0)
            remove_gpu = gr.Checkbox(label="Remove GPU After Training", value=True)
        train_button = gr.Button("Train !")

        status = gr.Textbox(label="Training status")

    # Argument order must match main()'s positional parameters.
    train_button.click(
        fn=main,
        inputs=[
            dataset_id,
            model_output_folder,
            instance_prompt,
            max_train_steps,
            checkpoint_steps,
            remove_gpu,
        ],
        outputs=[status],
    )

demo.queue().launch()