fffiloni commited on
Commit
9db9711
1 Parent(s): fc27a96

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +69 -5
app.py CHANGED
@@ -1,10 +1,23 @@
1
  import gradio as gr
2
  import os
3
  import subprocess
 
4
  from huggingface_hub import snapshot_download
5
 
6
  hf_token = os.environ.get("HF_TOKEN")
7
 
 
 
 
 
 
 
 
 
 
 
 
 
8
 
9
  def set_accelerate_default_config():
10
  try:
@@ -13,7 +26,7 @@ def set_accelerate_default_config():
13
  except subprocess.CalledProcessError as e:
14
  print(f"An error occurred: {e}")
15
 
16
- def train_dreambooth_lora_sdxl(instance_data_dir, lora_trained_xl_folder, instance_prompt, max_train_steps, checkpoint_steps):
17
 
18
  script_filename = "train_dreambooth_lora_sdxl.py" # Assuming it's in the same folder
19
 
@@ -47,15 +60,38 @@ def train_dreambooth_lora_sdxl(instance_data_dir, lora_trained_xl_folder, instan
47
  try:
48
  subprocess.run(command, check=True)
49
  print("Training is finished!")
 
 
50
  except subprocess.CalledProcessError as e:
51
  print(f"An error occurred: {e}")
 
 
 
 
 
 
 
 
 
 
 
52
 
53
  def main(dataset_id,
54
  lora_trained_xl_folder,
55
  instance_prompt,
56
  max_train_steps,
57
- checkpoint_steps):
 
 
 
 
 
58
 
 
 
 
 
 
59
  dataset_repo = dataset_id
60
 
61
  # Automatically set local_dir based on the last part of dataset_repo
@@ -81,12 +117,36 @@ def main(dataset_id,
81
  gr.Info("Training begins ...")
82
 
83
  instance_data_dir = repo_parts[-1]
84
- train_dreambooth_lora_sdxl(instance_data_dir, lora_trained_xl_folder, instance_prompt, max_train_steps, checkpoint_steps)
85
 
86
  return f"Done, your trained model has been stored in your models library: your_user_name/{lora-trained-xl-folder}"
87
 
88
  with gr.Blocks() as demo:
89
  with gr.Column():
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
90
  with gr.Row():
91
  dataset_id = gr.Textbox(label="Dataset ID", info="use one of your previously uploaded datasets on your HF profile", placeholder="diffusers/dog-example")
92
  instance_prompt = gr.Textbox(label="Concept prompt", info="concept prompt - use a unique, made up word to avoid collisions")
@@ -95,8 +155,11 @@ with gr.Blocks() as demo:
95
  model_output_folder = gr.Textbox(label="Output model folder name", placeholder="lora-trained-xl-folder")
96
  max_train_steps = gr.Number(label="Max Training Steps", value=500)
97
  checkpoint_steps = gr.Number(label="Checkpoints Steps", value=100)
 
98
  train_button = gr.Button("Train !")
99
- status = gr.Textbox(labe="Training status")
 
 
100
 
101
  train_button.click(
102
  fn = main,
@@ -105,7 +168,8 @@ with gr.Blocks() as demo:
105
  model_output_folder,
106
  instance_prompt,
107
  max_train_steps,
108
- checkpoint_steps
 
109
  ],
110
  outputs = [status]
111
  )
 
1
  import gradio as gr
2
  import os
3
  import subprocess
4
+ from subprocess import getoutput
5
  from huggingface_hub import snapshot_download
6
 
7
  hf_token = os.environ.get("HF_TOKEN")
8
 
9
+ is_shared_ui = True if "fffiloni/train-dreambooth-lora-sdxl" in os.environ['SPACE_ID'] else False
10
+
11
+
12
+ is_gpu_associated = torch.cuda.is_available()
13
+ if is_gpu_associated:
14
+ gpu_info = getoutput('nvidia-smi')
15
+ if("A10G" in gpu_info):
16
+ which_gpu = "A10G"
17
+ elif("T4" in gpu_info):
18
+ which_gpu = "T4"
19
+ else:
20
+ which_gpu = "CPU"
21
 
22
  def set_accelerate_default_config():
23
  try:
 
26
  except subprocess.CalledProcessError as e:
27
  print(f"An error occurred: {e}")
28
 
29
+ def train_dreambooth_lora_sdxl(instance_data_dir, lora_trained_xl_folder, instance_prompt, max_train_steps, checkpoint_steps, remove_gpu):
30
 
31
  script_filename = "train_dreambooth_lora_sdxl.py" # Assuming it's in the same folder
32
 
 
60
  try:
61
  subprocess.run(command, check=True)
62
  print("Training is finished!")
63
+ if remove_gpu:
64
+ swap_hardware(hf_token, "cpu-basic")
65
  except subprocess.CalledProcessError as e:
66
  print(f"An error occurred: {e}")
67
+
68
+ title="There was an error on during your training"
69
+ description=f'''
70
+ Unfortunately there was an error during training your {model_name} model.
71
+ Please check it out below. Feel free to report this issue to [SD-XL Dreambooth LoRa Training](https://huggingface.co/spaces/fffiloni/train-dreambooth-lora-sdxl):
72
+ ```
73
+ {str(e)}
74
+ ```
75
+ '''
76
+ swap_hardware(hf_token, "cpu-basic")
77
+ write_to_community(title,description,hf_token)
78
 
79
  def main(dataset_id,
80
  lora_trained_xl_folder,
81
  instance_prompt,
82
  max_train_steps,
83
+ checkpoint_steps,
84
+ remove_gpu):
85
+
86
+
87
+ if is_shared_ui:
88
+ raise gr.Error("This Space only works in duplicated instances")
89
 
90
+ if not is_gpu_associated:
91
+ raise gr.Error("Please associate a T4 or A10G GPU for this Space")
92
+
93
+ gr.Warning("## Training is ongoing ⌛... You can close this tab if you like or just wait. If you did not check the `Remove GPU After training`, you can come back here to try your model and upload it after training. Don't forget to remove the GPU attribution after you are done. ")
94
+
95
  dataset_repo = dataset_id
96
 
97
  # Automatically set local_dir based on the last part of dataset_repo
 
117
  gr.Info("Training begins ...")
118
 
119
  instance_data_dir = repo_parts[-1]
120
+ train_dreambooth_lora_sdxl(instance_data_dir, lora_trained_xl_folder, instance_prompt, max_train_steps, checkpoint_steps, remove_gpu)
121
 
122
  return f"Done, your trained model has been stored in your models library: your_user_name/{lora-trained-xl-folder}"
123
 
124
  with gr.Blocks() as demo:
125
  with gr.Column():
126
+ if is_shared_ui:
127
+ top_description = gr.HTML(f'''
128
+ <div class="gr-prose" style="max-width: 80%">
129
+ <h2>Attention - This Space doesn't work in this shared UI</h2>
130
+ <p>For it to work, you can duplicate the Space and run it on your own profile using a (paid) private T4-small or A10G-small GPU for training. A T4 costs US$0.60/h, so it should cost < US$1 to train most models using default settings with it!&nbsp;&nbsp;<a class="duplicate-button" style="display:inline-block" target="_blank" href="https://huggingface.co/spaces/{os.environ['SPACE_ID']}?duplicate=true"><img src="https://img.shields.io/badge/-Duplicate%20Space-blue?labelColor=white&style=flat&logo=&logoWidth=14" alt="Duplicate Space"></a></p>
131
+ <img class="instruction" src="file=duplicate.png">
132
+ <img class="arrow" src="file=arrow.png" />
133
+ </div>
134
+ ''')
135
+ else:
136
+ if(is_gpu_associated):
137
+ top_description = gr.HTML(f'''
138
+ <div class="gr-prose" style="max-width: 80%">
139
+ <h2>You have successfully associated a {which_gpu} GPU to the SD-XL Dreambooth LoRa Training Space 🎉</h2>
140
+ <p>You can now train your model! You will be billed by the minute from when you activated the GPU until when it is turned it off.</p>
141
+ </div>
142
+ ''')
143
+ else:
144
+ top_description = gr.HTML(f'''
145
+ <div class="gr-prose" style="max-width: 80%">
146
+ <h2>You have successfully duplicated the SD-XL Dreambooth LoRa Training Space 🎉</h2>
147
+ <p>There's only one step left before you can train your model: <a href="https://huggingface.co/spaces/{os.environ['SPACE_ID']}/settings" style="text-decoration: underline" target="_blank">attribute a <b>T4-small or A10G-small GPU</b> to it (via the Settings tab)</a> and run the training below. You will be billed by the minute from when you activate the GPU until when it is turned it off.</p>
148
+ </div>
149
+ ''')
150
  with gr.Row():
151
  dataset_id = gr.Textbox(label="Dataset ID", info="use one of your previously uploaded datasets on your HF profile", placeholder="diffusers/dog-example")
152
  instance_prompt = gr.Textbox(label="Concept prompt", info="concept prompt - use a unique, made up word to avoid collisions")
 
155
  model_output_folder = gr.Textbox(label="Output model folder name", placeholder="lora-trained-xl-folder")
156
  max_train_steps = gr.Number(label="Max Training Steps", value=500)
157
  checkpoint_steps = gr.Number(label="Checkpoints Steps", value=100)
158
+ remove_gpu = gr.Checkbox(label="Remove GPU After Training", value=True)
159
  train_button = gr.Button("Train !")
160
+
161
+
162
+ status = gr.Textbox(label="Training status")
163
 
164
  train_button.click(
165
  fn = main,
 
168
  model_output_folder,
169
  instance_prompt,
170
  max_train_steps,
171
+ checkpoint_steps,
172
+ remove_gpu
173
  ],
174
  outputs = [status]
175
  )