fffiloni committed
Commit 463536c
1 Parent(s): fd13994

Update app.py

Files changed (1)
  1. app.py  +27 -12
app.py CHANGED
@@ -4,7 +4,7 @@ import subprocess
 from huggingface_hub import snapshot_download
 
 hf_token = os.environ.get("HF_TOKEN")
-print(hf_token)
+
 
 def set_accelerate_default_config():
     try:
@@ -13,7 +13,7 @@ def set_accelerate_default_config():
     except subprocess.CalledProcessError as e:
         print(f"An error occurred: {e}")
 
-def train_dreambooth_lora_sdxl(instance_data_dir):
+def train_dreambooth_lora_sdxl(instance_data_dir, lora_trained_xl_folder, instance_prompt, max_train_steps, checkpoint_steps):
 
     script_filename = "train_dreambooth_lora_sdxl.py"  # Assuming it's in the same folder
 
@@ -24,9 +24,9 @@ def train_dreambooth_lora_sdxl(instance_data_dir):
         "--pretrained_model_name_or_path=stabilityai/stable-diffusion-xl-base-1.0",
         "--pretrained_vae_model_name_or_path=madebyollin/sdxl-vae-fp16-fix",
         f"--instance_data_dir={instance_data_dir}",
-        "--output_dir=lora-trained-xl-colab_2",
+        f"--output_dir={lora_trained_xl_folder}",
         "--mixed_precision=fp16",
-        "--instance_prompt=egnestl",
+        f"--instance_prompt={instance_prompt}",
         "--resolution=1024",
         "--train_batch_size=2",
         "--gradient_accumulation_steps=2",
@@ -37,8 +37,8 @@ def train_dreambooth_lora_sdxl(instance_data_dir):
         "--enable_xformers_memory_efficient_attention",
         "--mixed_precision=fp16",
         "--use_8bit_adam",
-        "--max_train_steps=25",
-        "--checkpointing_steps=717",
+        f"--max_train_steps={max_train_steps}",
+        f"--checkpointing_steps={checkpoint_steps}",
         "--seed=0",
         "--push_to_hub",
         f"--hub_token={hf_token}"
@@ -50,9 +50,13 @@ def train_dreambooth_lora_sdxl(instance_data_dir):
     except subprocess.CalledProcessError as e:
         print(f"An error occurred: {e}")
 
-def main(dataset_url):
+def main(dataset_id,
+         lora_trained_xl_folder,
+         instance_prompt,
+         max_train_steps,
+         checkpoint_steps):
 
-    dataset_repo = dataset_url
+    dataset_repo = dataset_id
 
     # Automatically set local_dir based on the last part of dataset_repo
     repo_parts = dataset_repo.split("/")
@@ -77,19 +81,30 @@ def main(dataset_url):
     gr.Info("Training begins ...")
 
     instance_data_dir = repo_parts[-1]
-    train_dreambooth_lora_sdxl(instance_data_dir)
+    train_dreambooth_lora_sdxl(instance_data_dir, lora_trained_xl_folder, instance_prompt, max_train_steps, checkpoint_steps)
 
-    return "Done"
+    return f"Done, your trained model has been stored in your models library: your_user_name/{lora_trained_xl_folder}"
 
 with gr.Blocks() as demo:
     with gr.Column():
-        dataset_id = gr.Textbox(label="Dataset ID")
+        dataset_id = gr.Textbox(label="Dataset ID", placeholder="diffusers/dog-example")
+        instance_prompt = gr.Textbox(label="Concept prompt", info="concept prompt - use a unique, made up word to avoid collisions")
+        model_output_folder = gr.Textbox(label="Output model folder name", placeholder="lora-trained-xl-folder")
+        with gr.Row():
+            max_train_steps = gr.Number(value=500)
+            checkpoint_steps = gr.Number(value=100)
         train_button = gr.Button("Train !")
         status = gr.Textbox(label="Training status")
 
         train_button.click(
             fn = main,
-            inputs = [dataset_id],
+            inputs = [
+                dataset_id,
+                model_output_folder,
+                instance_prompt,
+                max_train_steps,
+                checkpoint_steps
+            ],
             outputs = [status]
         )
 
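For reference, here is a minimal sketch of how the newly parameterized training call fits together end to end. The hunks above omit the surrounding lines that build the command list and launch the script, so the plain-`python` launch and `subprocess.run` framing below are assumptions for illustration, not the exact contents of app.py; only the flag values shown in the hunks are taken from the diff.

import os
import subprocess

hf_token = os.environ.get("HF_TOKEN")

def train_dreambooth_lora_sdxl(instance_data_dir, lora_trained_xl_folder, instance_prompt,
                               max_train_steps, checkpoint_steps):
    script_filename = "train_dreambooth_lora_sdxl.py"  # assumed to sit next to app.py
    # Assumed wrapper: the diff only shows the flag list, so invoking the script
    # with plain `python` here is a guess at the elided context lines.
    command = [
        "python", script_filename,
        "--pretrained_model_name_or_path=stabilityai/stable-diffusion-xl-base-1.0",
        "--pretrained_vae_model_name_or_path=madebyollin/sdxl-vae-fp16-fix",
        f"--instance_data_dir={instance_data_dir}",
        f"--output_dir={lora_trained_xl_folder}",          # previously hard-coded to lora-trained-xl-colab_2
        f"--instance_prompt={instance_prompt}",            # previously hard-coded to egnestl
        "--resolution=1024",
        "--train_batch_size=2",
        "--gradient_accumulation_steps=2",
        "--enable_xformers_memory_efficient_attention",
        "--mixed_precision=fp16",
        "--use_8bit_adam",
        f"--max_train_steps={int(max_train_steps)}",       # previously hard-coded to 25; gr.Number returns a float
        f"--checkpointing_steps={int(checkpoint_steps)}",  # previously hard-coded to 717
        "--seed=0",
        "--push_to_hub",
        f"--hub_token={hf_token}",
    ]
    try:
        subprocess.run(command, check=True)
    except subprocess.CalledProcessError as e:
        print(f"An error occurred: {e}")

Note that Gradio passes the `inputs` component values to the callback positionally, so the order in `inputs = [...]` must match `main`'s parameter order (dataset ID, output folder, concept prompt, max train steps, checkpoint steps); the diff above keeps those two lists aligned.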