multimodalart HF staff committed on
Commit
ac586a8
1 Parent(s): cb024a4

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +34 -23
app.py CHANGED
@@ -16,21 +16,6 @@ css = '''
16
  shutil.unpack_archive("mix.zip", "mix")
17
  model_to_load = "multimodalart/sd-fine-tunable"
18
  maximum_concepts = 3
19
- def swap_values_files(*total_files):
20
- file_counter = 0
21
- for files in total_files:
22
- if(files):
23
- for file in files:
24
- filename = Path(file.orig_name).stem
25
- pt=''.join([i for i in filename if not i.isdigit()])
26
- pt=pt.replace("_"," ")
27
- pt=pt.replace("(","")
28
- pt=pt.replace(")","")
29
- instance_prompt = pt
30
- print(instance_prompt)
31
- file_counter += 1
32
- training_steps = (file_counter*200)
33
- return training_steps
34
 
35
  def swap_text(option):
36
  mandatory_liability = "You must have the right to do so and you are liable for the images you use"
@@ -47,6 +32,24 @@ def swap_text(option):
47
  freeze_for = 10
48
  return [f"You are going to train a `style`, upload 10-20 images of the style you are planning on training on. Name the files with the words you would like {mandatory_liability}:", '''<img src="file/trsl_style.png" />''', f"You should name your files with a unique word that represent your concept (e.g.: `{instance_prompt_example}` here). Images will be automatically cropped to 512x512.", freeze_for]
49
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
50
  def train(*inputs):
51
  if os.path.exists("diffusers_model.zip"): os.remove("diffusers_model.zip")
52
  if os.path.exists("model.ckpt"): os.remove("model.ckpt")
@@ -164,7 +167,7 @@ def train(*inputs):
164
  shutil.rmtree('instance_images')
165
  shutil.make_archive("diffusers_model", 'zip', "output_model")
166
  torch.cuda.empty_cache()
167
- return [gr.update(visible=True, value=["diffusers_model.zip"]), gr.update(visible=True), gr.update(visible=True)]
168
 
169
  def generate(prompt):
170
  from diffusers import StableDiffusionPipeline
@@ -177,7 +180,7 @@ def generate(prompt):
177
  def push(path):
178
  pass
179
 
180
- def convert():
181
  convert("output_model", "model.ckpt")
182
  return gr.update(visible=True, value=["diffusers_model.zip", "model.ckpt"])
183
 
@@ -192,6 +195,13 @@ with gr.Blocks(css=css) as demo:
192
  <img class="arrow" src="file/arrow.png" />
193
  </div>
194
  ''')
 
 
 
 
 
 
 
195
  gr.Markdown("# Dreambooth training")
196
  gr.Markdown("Customize Stable Diffusion by giving it with few-shot examples")
197
  with gr.Row():
@@ -253,10 +263,11 @@ with gr.Blocks(css=css) as demo:
253
  steps = gr.Number(label="How many steps", value=800)
254
  perc_txt_encoder = gr.Number(label="Percentage of the training steps the text-encoder should be trained as well", value=30)
255
 
256
- #for file in file_collection:
257
- # file.change(fn=swap_values_files, inputs=file_collection, outputs=[steps])
258
 
259
  type_of_thing.change(fn=swap_text, inputs=[type_of_thing], outputs=[thing_description, thing_image_example, things_naming, perc_txt_encoder], queue=False)
 
260
  train_btn = gr.Button("Start Training")
261
  with gr.Box(visible=False) as try_your_model:
262
  gr.Markdown("Try your model")
@@ -268,11 +279,11 @@ with gr.Blocks(css=css) as demo:
268
  gr.Markdown("Push to Hugging Face Hub")
269
  model_repo_tag = gr.Textbox(label="Model name or URL", placeholder="username/model_name")
270
  push_button = gr.Button("Push to the Hub")
271
- result = gr.File(label="Download the uploaded models in the diffusers format (zip file are diffusers weights are CompVis/AUTOMATIC1111 weights)", visible=True)
272
- convert_button = gr.Button("Convert to CKPT")
273
 
274
- train_btn.click(fn=train, inputs=is_visible+concept_collection+file_collection+[type_of_thing]+[steps]+[perc_txt_encoder]+[swap_auto_calculated], outputs=[result, try_your_model, push_to_hub])
275
  generate_button.click(fn=generate, inputs=prompt, outputs=result)
276
  push_button.click(fn=push, inputs=model_repo_tag, outputs=[])
277
- convert_button.click(fn=convert, inputs=[], outputs=result)
278
  demo.launch()
 
16
  shutil.unpack_archive("mix.zip", "mix")
17
  model_to_load = "multimodalart/sd-fine-tunable"
18
  maximum_concepts = 3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
19
 
20
  def swap_text(option):
21
  mandatory_liability = "You must have the right to do so and you are liable for the images you use"
 
32
  freeze_for = 10
33
  return [f"You are going to train a `style`, upload 10-20 images of the style you are planning on training on. Name the files with the words you would like {mandatory_liability}:", '''<img src="file/trsl_style.png" />''', f"You should name your files with a unique word that represent your concept (e.g.: `{instance_prompt_example}` here). Images will be automatically cropped to 512x512.", freeze_for]
34
 
35
+ def count_files(*inputs):
36
+ file_counter = 0
37
+ for i, input in enumerate(inputs):
38
+ if(i < maximum_concepts-1):
39
+ if(input):
40
+ files = inputs[i+(maximum_concepts*2)]
41
+ for j, tile_temp in enumerate(files):
42
+ file_counter+= 1
43
+ uses_custom = inputs[-1]
44
+ type_of_thing = inputs[-4]
45
+ if(uses_custom):
46
+ Training_Steps = int(inputs[-3])
47
+ else:
48
+ if(type_of_thing == "person"):
49
+ Training_Steps = file_counter*200*2
50
+ else:
51
+ Training_Steps = file_counter*200
52
+ return(gr.update(visible=True, value=f"You are going to train {file_counter} files for {Training_Steps} steps. This should take around {round(Training_Steps/1.5, 2)} seconds, or {round((Training_Steps/1.5)/3600, 2)}. The T4 GPU costs US$0.60 for 1h, so the estimated costs for this training run should be {round(((Training_Steps/1.5)/3600)*0.6, 2)}"))
53
  def train(*inputs):
54
  if os.path.exists("diffusers_model.zip"): os.remove("diffusers_model.zip")
55
  if os.path.exists("model.ckpt"): os.remove("model.ckpt")
 
167
  shutil.rmtree('instance_images')
168
  shutil.make_archive("diffusers_model", 'zip', "output_model")
169
  torch.cuda.empty_cache()
170
+ return [gr.update(visible=True, value=["diffusers_model.zip"]), gr.update(visible=True), gr.update(visible=True), gr.update(visible=True)]
171
 
172
  def generate(prompt):
173
  from diffusers import StableDiffusionPipeline
 
180
  def push(path):
181
  pass
182
 
183
+ def convert_to_ckpt():
184
  convert("output_model", "model.ckpt")
185
  return gr.update(visible=True, value=["diffusers_model.zip", "model.ckpt"])
186
 
 
195
  <img class="arrow" src="file/arrow.png" />
196
  </div>
197
  ''')
198
+ else:
199
+ gr.HTML('''
200
+ <div class="gr-prose" style="max-width: 80%">
201
+ <h2>You have successfully cloned the Dreambooth Training Space</h2>
202
+ <p><a href="#">Now you can attribute a T4 GPU to it</a> (by going to the Settings tab) and run the training below. The GPU will be automatically unassigned after training is over. So you will be billed by the minute between when you activate the GPU and when it finishes training.</p>
203
+ </div>
204
+ ''')
205
  gr.Markdown("# Dreambooth training")
206
  gr.Markdown("Customize Stable Diffusion by giving it with few-shot examples")
207
  with gr.Row():
 
263
  steps = gr.Number(label="How many steps", value=800)
264
  perc_txt_encoder = gr.Number(label="Percentage of the training steps the text-encoder should be trained as well", value=30)
265
 
266
+ for file in file_collection:
267
+ file.change(fn=count_files, inputs=file_collection+[type_of_thing]+[steps]+[perc_txt_encoder]+[swap_auto_calculated], outputs=[training_summary, training_summary])
268
 
269
  type_of_thing.change(fn=swap_text, inputs=[type_of_thing], outputs=[thing_description, thing_image_example, things_naming, perc_txt_encoder], queue=False)
270
+ training_summary = gr.Textbox("", visible=False, label="Training Summary")
271
  train_btn = gr.Button("Start Training")
272
  with gr.Box(visible=False) as try_your_model:
273
  gr.Markdown("Try your model")
 
279
  gr.Markdown("Push to Hugging Face Hub")
280
  model_repo_tag = gr.Textbox(label="Model name or URL", placeholder="username/model_name")
281
  push_button = gr.Button("Push to the Hub")
282
+ result = gr.File(label="Download the uploaded models in the diffusers format", visible=True)
283
+ convert_button = gr.Button("Convert to CKPT", visible=False)
284
 
285
+ train_btn.click(fn=train, inputs=is_visible+concept_collection+file_collection+[type_of_thing]+[steps]+[perc_txt_encoder]+[swap_auto_calculated], outputs=[result, try_your_model, push_to_hub, convert_button])
286
  generate_button.click(fn=generate, inputs=prompt, outputs=result)
287
  push_button.click(fn=push, inputs=model_repo_tag, outputs=[])
288
+ convert_button.click(fn=convert_to_ckpt, inputs=[], outputs=result)
289
  demo.launch()