Spaces:
Runtime error
Runtime error
Update app.py
Browse files
app.py
CHANGED
@@ -4,7 +4,7 @@ import subprocess
|
|
4 |
from huggingface_hub import snapshot_download
|
5 |
|
6 |
hf_token = os.environ.get("HF_TOKEN")
|
7 |
-
|
8 |
|
9 |
def set_accelerate_default_config():
|
10 |
try:
|
@@ -13,7 +13,7 @@ def set_accelerate_default_config():
|
|
13 |
except subprocess.CalledProcessError as e:
|
14 |
print(f"An error occurred: {e}")
|
15 |
|
16 |
-
def train_dreambooth_lora_sdxl(instance_data_dir):
|
17 |
|
18 |
script_filename = "train_dreambooth_lora_sdxl.py" # Assuming it's in the same folder
|
19 |
|
@@ -24,9 +24,9 @@ def train_dreambooth_lora_sdxl(instance_data_dir):
|
|
24 |
"--pretrained_model_name_or_path=stabilityai/stable-diffusion-xl-base-1.0",
|
25 |
"--pretrained_vae_model_name_or_path=madebyollin/sdxl-vae-fp16-fix",
|
26 |
f"--instance_data_dir={instance_data_dir}",
|
27 |
-
"--output_dir=lora-trained-xl-
|
28 |
"--mixed_precision=fp16",
|
29 |
-
"--instance_prompt=
|
30 |
"--resolution=1024",
|
31 |
"--train_batch_size=2",
|
32 |
"--gradient_accumulation_steps=2",
|
@@ -37,8 +37,8 @@ def train_dreambooth_lora_sdxl(instance_data_dir):
|
|
37 |
"--enable_xformers_memory_efficient_attention",
|
38 |
"--mixed_precision=fp16",
|
39 |
"--use_8bit_adam",
|
40 |
-
"--max_train_steps=
|
41 |
-
"--checkpointing_steps=
|
42 |
"--seed=0",
|
43 |
"--push_to_hub",
|
44 |
f"--hub_token={hf_token}"
|
@@ -50,9 +50,13 @@ def train_dreambooth_lora_sdxl(instance_data_dir):
|
|
50 |
except subprocess.CalledProcessError as e:
|
51 |
print(f"An error occurred: {e}")
|
52 |
|
53 |
-
def main(
|
|
|
|
|
|
|
|
|
54 |
|
55 |
-
dataset_repo =
|
56 |
|
57 |
# Automatically set local_dir based on the last part of dataset_repo
|
58 |
repo_parts = dataset_repo.split("/")
|
@@ -77,19 +81,30 @@ def main(dataset_url):
|
|
77 |
gr.Info("Training begins ...")
|
78 |
|
79 |
instance_data_dir = repo_parts[-1]
|
80 |
-
train_dreambooth_lora_sdxl(instance_data_dir)
|
81 |
|
82 |
-
return "Done"
|
83 |
|
84 |
with gr.Blocks() as demo:
|
85 |
with gr.Column():
|
86 |
-
dataset_id = gr.Textbox(label="Dataset ID")
|
|
|
|
|
|
|
|
|
|
|
87 |
train_button = gr.Button("Train !")
|
88 |
status = gr.Textbox(labe="Training status")
|
89 |
|
90 |
train_button.click(
|
91 |
fn = main,
|
92 |
-
inputs = [
|
|
|
|
|
|
|
|
|
|
|
|
|
93 |
outputs = [status]
|
94 |
)
|
95 |
|
|
|
4 |
from huggingface_hub import snapshot_download
|
5 |
|
6 |
hf_token = os.environ.get("HF_TOKEN")
|
7 |
+
|
8 |
|
9 |
def set_accelerate_default_config():
|
10 |
try:
|
|
|
13 |
except subprocess.CalledProcessError as e:
|
14 |
print(f"An error occurred: {e}")
|
15 |
|
16 |
+
def train_dreambooth_lora_sdxl(instance_data_dir, lora-trained-xl-folder, instance_prompt, max_train_steps, checkpoint_steps):
|
17 |
|
18 |
script_filename = "train_dreambooth_lora_sdxl.py" # Assuming it's in the same folder
|
19 |
|
|
|
24 |
"--pretrained_model_name_or_path=stabilityai/stable-diffusion-xl-base-1.0",
|
25 |
"--pretrained_vae_model_name_or_path=madebyollin/sdxl-vae-fp16-fix",
|
26 |
f"--instance_data_dir={instance_data_dir}",
|
27 |
+
f"--output_dir={lora-trained-xl-folder}",
|
28 |
"--mixed_precision=fp16",
|
29 |
+
f"--instance_prompt={instance_prompt}",
|
30 |
"--resolution=1024",
|
31 |
"--train_batch_size=2",
|
32 |
"--gradient_accumulation_steps=2",
|
|
|
37 |
"--enable_xformers_memory_efficient_attention",
|
38 |
"--mixed_precision=fp16",
|
39 |
"--use_8bit_adam",
|
40 |
+
f"--max_train_steps={max_train_steps}",
|
41 |
+
f"--checkpointing_steps={checkpoint_steps}",
|
42 |
"--seed=0",
|
43 |
"--push_to_hub",
|
44 |
f"--hub_token={hf_token}"
|
|
|
50 |
except subprocess.CalledProcessError as e:
|
51 |
print(f"An error occurred: {e}")
|
52 |
|
53 |
+
def main(dataset_id,
|
54 |
+
lora-trained-xl-folder,
|
55 |
+
instance_prompt,
|
56 |
+
max_train_steps,
|
57 |
+
checkpoint_steps):
|
58 |
|
59 |
+
dataset_repo = dataset_id
|
60 |
|
61 |
# Automatically set local_dir based on the last part of dataset_repo
|
62 |
repo_parts = dataset_repo.split("/")
|
|
|
81 |
gr.Info("Training begins ...")
|
82 |
|
83 |
instance_data_dir = repo_parts[-1]
|
84 |
+
train_dreambooth_lora_sdxl(instance_data_dir, lora-trained-xl-folder, instance_prompt, max_train_steps, checkpoint_steps)
|
85 |
|
86 |
+
return f"Done, your trained model has been stored in your models library: your_user_name/{lora-trained-xl-folder}"
|
87 |
|
88 |
with gr.Blocks() as demo:
|
89 |
with gr.Column():
|
90 |
+
dataset_id = gr.Textbox(label="Dataset ID", placeholder="diffusers/dog-example")
|
91 |
+
instance_prompt = gr.Textbox(label="Concept prompt", info="concept prompt - use a unique, made up word to avoid collisions")
|
92 |
+
model_output_folder = gr.Textbox(label="Output model folder name", placeholder="lora-trained-xl-folder")
|
93 |
+
with gr.Row():
|
94 |
+
max_train_steps = gr.Number(value=500)
|
95 |
+
checkpoint_steps = gr.Number(value=100)
|
96 |
train_button = gr.Button("Train !")
|
97 |
status = gr.Textbox(labe="Training status")
|
98 |
|
99 |
train_button.click(
|
100 |
fn = main,
|
101 |
+
inputs = [
|
102 |
+
dataset_id,
|
103 |
+
instance_prompt,
|
104 |
+
model_output_folder,
|
105 |
+
max_train_steps,
|
106 |
+
checkpoint_steps
|
107 |
+
],
|
108 |
outputs = [status]
|
109 |
)
|
110 |
|