multimodalart committed
Commit c9de947
1 Parent(s): 37393f2

Update app.py

Files changed (1): app.py (+19 -3)
app.py CHANGED
@@ -4,6 +4,7 @@ from pathlib import Path
 import argparse
 import shutil
 from train_dreambooth import run_training
+from converttosd import convert
 from PIL import Image
 import torch
 
@@ -47,6 +48,8 @@ def swap_text(option):
         return [f"You are going to train a `style`, upload 10-20 images of the style you are planning on training on. Name the files with the words you would like {mandatory_liability}:", '''<img src="file/trsl_style.png" />''', f"You should name your files with a unique word that represent your concept (e.g.: `{instance_prompt_example}` here). Images will be automatically cropped to 512x512.", freeze_for]
 
 def train(*inputs):
+    if os.path.exists("diffusers_model.zip"): os.remove("diffusers_model.zip")
+    if os.path.exists("model.ckpt"): os.remove("model.ckpt")
     file_counter = 0
     for i, input in enumerate(inputs):
         if(i < maximum_concepts-1):
@@ -156,12 +159,23 @@ def train(*inputs):
         max_train_steps=Training_Steps,
     )
     run_training(args_general)
-
+    convert("output_model", "model.ckpt")
     shutil.rmtree('instance_images')
-    shutil.make_archive("output_model", 'zip', "output_model")
+    shutil.make_archive("diffusers_model", 'zip', "output_model")
     shutil.rmtree("output_model")
     torch.cuda.empty_cache()
-    return [gr.update(visible=True, value="output_model.zip"), gr.update(visible=True), gr.update(visible=True)]
+    return [gr.update(visible=True, value=["diffusers_model.zip", "model.ckpt"]), gr.update(visible=True), gr.update(visible=True)]
+
+def generate(prompt):
+    from diffusers import StableDiffusionPipeline
+
+    pipe = StableDiffusionPipeline.from_pretrained("./output_model", torch_dtype=torch.float16)
+    pipe = pipe.to("cuda")
+    image = pipe(prompt).images[0]
+    return(image)
+
+def push_button(path):
+    pass
 
 with gr.Blocks(css=css) as demo:
     with gr.Box():
@@ -252,4 +266,6 @@ with gr.Blocks(css=css) as demo:
     push_button = gr.Button("Push to the Hub")
     result = gr.File(label="Download the uploaded models (zip file are diffusers weights, *.ckpt are CompVis/AUTOMATIC1111 weights)", visible=True)
     train_btn.click(fn=train, inputs=is_visible+concept_collection+file_collection+[type_of_thing]+[steps]+[perc_txt_encoder]+[swap_auto_calculated], outputs=[result, try_your_model, push_to_hub])
+    generate_button.click(fn=generate, inputs=prompt, outputs=result)
+    push_button.click(fn=push_to_hub, inputs=model_repo_tag, outputs=[])
 demo.launch()
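
For reference, a minimal sketch (not part of this commit) of how the two files the updated train() now returns through the result gr.File component could be used locally. It assumes diffusers_model.zip has been downloaded next to the script; the prompt string and output filenames are placeholders.

# Sketch: consuming the artifacts produced by the updated train() step.
import zipfile

import torch
from diffusers import StableDiffusionPipeline

# Unpack the diffusers-format weights (the zip archive built by shutil.make_archive).
with zipfile.ZipFile("diffusers_model.zip") as archive:
    archive.extractall("diffusers_model")

# Load the fine-tuned pipeline from the extracted folder and run it,
# mirroring what the new generate() function does inside the Space.
pipe = StableDiffusionPipeline.from_pretrained("diffusers_model", torch_dtype=torch.float16)
pipe = pipe.to("cuda")
image = pipe("a photo of the trained concept").images[0]  # placeholder prompt
image.save("sample.png")

# model.ckpt, the other returned file produced by convert("output_model", "model.ckpt"),
# is a CompVis/AUTOMATIC1111-style checkpoint meant for those UIs rather than diffusers.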