Commit
•
c9de947
1
Parent(s):
37393f2
Update app.py
Browse files
app.py
CHANGED
@@ -4,6 +4,7 @@ from pathlib import Path
|
|
4 |
import argparse
|
5 |
import shutil
|
6 |
from train_dreambooth import run_training
|
|
|
7 |
from PIL import Image
|
8 |
import torch
|
9 |
|
@@ -47,6 +48,8 @@ def swap_text(option):
|
|
47 |
return [f"You are going to train a `style`, upload 10-20 images of the style you are planning on training on. Name the files with the words you would like {mandatory_liability}:", '''<img src="file/trsl_style.png" />''', f"You should name your files with a unique word that represent your concept (e.g.: `{instance_prompt_example}` here). Images will be automatically cropped to 512x512.", freeze_for]
|
48 |
|
49 |
def train(*inputs):
|
|
|
|
|
50 |
file_counter = 0
|
51 |
for i, input in enumerate(inputs):
|
52 |
if(i < maximum_concepts-1):
|
@@ -156,12 +159,23 @@ def train(*inputs):
|
|
156 |
max_train_steps=Training_Steps,
|
157 |
)
|
158 |
run_training(args_general)
|
159 |
-
|
160 |
shutil.rmtree('instance_images')
|
161 |
-
shutil.make_archive("
|
162 |
shutil.rmtree("output_model")
|
163 |
torch.cuda.empty_cache()
|
164 |
-
return [gr.update(visible=True, value="
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
165 |
|
166 |
with gr.Blocks(css=css) as demo:
|
167 |
with gr.Box():
|
@@ -252,4 +266,6 @@ with gr.Blocks(css=css) as demo:
|
|
252 |
push_button = gr.Button("Push to the Hub")
|
253 |
result = gr.File(label="Download the uploaded models (zip file are diffusers weights, *.ckpt are CompVis/AUTOMATIC1111 weights)", visible=True)
|
254 |
train_btn.click(fn=train, inputs=is_visible+concept_collection+file_collection+[type_of_thing]+[steps]+[perc_txt_encoder]+[swap_auto_calculated], outputs=[result, try_your_model, push_to_hub])
|
|
|
|
|
255 |
demo.launch()
|
|
|
4 |
import argparse
|
5 |
import shutil
|
6 |
from train_dreambooth import run_training
|
7 |
+
from converttosd import convert
|
8 |
from PIL import Image
|
9 |
import torch
|
10 |
|
|
|
48 |
return [f"You are going to train a `style`, upload 10-20 images of the style you are planning on training on. Name the files with the words you would like {mandatory_liability}:", '''<img src="file/trsl_style.png" />''', f"You should name your files with a unique word that represent your concept (e.g.: `{instance_prompt_example}` here). Images will be automatically cropped to 512x512.", freeze_for]
|
49 |
|
50 |
def train(*inputs):
|
51 |
+
if os.path.exists("diffusers_model.zip"): os.remove("diffusers_model.zip")
|
52 |
+
if os.path.exists("model.ckpt"): os.remove("model.ckpt")
|
53 |
file_counter = 0
|
54 |
for i, input in enumerate(inputs):
|
55 |
if(i < maximum_concepts-1):
|
|
|
159 |
max_train_steps=Training_Steps,
|
160 |
)
|
161 |
run_training(args_general)
|
162 |
+
convert("output_model", "model.ckpt")
|
163 |
shutil.rmtree('instance_images')
|
164 |
+
shutil.make_archive("diffusers_model", 'zip', "output_model")
|
165 |
shutil.rmtree("output_model")
|
166 |
torch.cuda.empty_cache()
|
167 |
+
return [gr.update(visible=True, value=["diffusers_model.zip", "model.ckpt"]), gr.update(visible=True), gr.update(visible=True)]
|
168 |
+
|
169 |
+
def generate(prompt):
|
170 |
+
from diffusers import StableDiffusionPipeline
|
171 |
+
|
172 |
+
pipe = StableDiffusionPipeline.from_pretrained("./output_model", torch_dtype=torch.float16)
|
173 |
+
pipe = pipe.to("cuda")
|
174 |
+
image = pipe(prompt).images[0]
|
175 |
+
return(image)
|
176 |
+
|
177 |
+
def push_button(path):
|
178 |
+
pass
|
179 |
|
180 |
with gr.Blocks(css=css) as demo:
|
181 |
with gr.Box():
|
|
|
266 |
push_button = gr.Button("Push to the Hub")
|
267 |
result = gr.File(label="Download the uploaded models (zip file are diffusers weights, *.ckpt are CompVis/AUTOMATIC1111 weights)", visible=True)
|
268 |
train_btn.click(fn=train, inputs=is_visible+concept_collection+file_collection+[type_of_thing]+[steps]+[perc_txt_encoder]+[swap_auto_calculated], outputs=[result, try_your_model, push_to_hub])
|
269 |
+
generate_button.click(fn=generate, inputs=prompt, outputs=result)
|
270 |
+
push_button.click(fn=push_to_hub, inputs=model_repo_tag, outputs=[])
|
271 |
demo.launch()
|