fffiloni committed on
Commit
e45a98d
1 Parent(s): 9ea3009

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +39 -10
app.py CHANGED
@@ -21,7 +21,7 @@ pipe = ImagicStableDiffusionPipeline.from_pretrained(
21
 
22
  generator = torch.Generator("cuda").manual_seed(0)
23
 
24
- def infer(prompt, init_image, trn_steps):
25
  init_image = Image.open(init_image).convert("RGB")
26
  init_image = init_image.resize((256, 256))
27
 
@@ -32,7 +32,7 @@ def infer(prompt, init_image, trn_steps):
32
  guidance_scale=7.5,
33
  num_inference_steps=50,
34
  generator=generator,
35
- text_embedding_optimization_steps=500,
36
  model_fine_tuning_optimization_steps=trn_steps)
37
 
38
  with torch.no_grad():
@@ -40,12 +40,35 @@ def infer(prompt, init_image, trn_steps):
40
 
41
 
42
 
43
- res = pipe(alpha=1)
44
 
45
 
46
- return res.images[0]
 
47
 
 
 
 
48
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
49
  title = """
50
  <div style="text-align: center; max-width: 650px; margin: 0 auto;">
51
  <div
@@ -117,17 +140,23 @@ with gr.Blocks(css=css) as block:
117
 
118
  prompt_input = gr.Textbox(label="Target text", placeholder="Describe the image with what you want to change about the subject")
119
  image_init = gr.Image(source="upload", type="filepath",label="Input Image")
120
- trn_steps = gr.Slider(250, 1000, value=500, label="finetuning steps")
121
- submit_btn = gr.Button("Train")
122
-
 
 
 
 
 
123
  image_output = gr.Image(label="Edited image")
124
 
125
- examples=[['a sitting dog','imagic-dog.png', 250], ['a photo of a bird spreading wings','imagic-bird.png',250]]
126
- ex = gr.Examples(examples=examples, fn=infer, inputs=[prompt_input,image_init,trn_steps], outputs=[image_output], cache_examples=False, run_on_click=False)
127
 
128
 
129
  gr.HTML(article)
130
 
131
- submit_btn.click(fn=infer, inputs=[prompt_input,image_init,trn_steps], outputs=[image_output])
 
132
 
133
  block.queue(max_size=12).launch(show_api=False)
 
21
 
22
  generator = torch.Generator("cuda").manual_seed(0)
23
 
24
+ def train(prompt, init_image, trn_text, trn_steps):
25
  init_image = Image.open(init_image).convert("RGB")
26
  init_image = init_image.resize((256, 256))
27
 
 
32
  guidance_scale=7.5,
33
  num_inference_steps=50,
34
  generator=generator,
35
+ text_embedding_optimization_steps=trn_text,
36
  model_fine_tuning_optimization_steps=trn_steps)
37
 
38
  with torch.no_grad():
 
40
 
41
 
42
 
 
43
 
44
 
45
+
46
+ return "Training is finished !"
47
 
48
+ def generate(prompt, init_image):
49
+ init_image = Image.open(init_image).convert("RGB")
50
+ init_image = init_image.resize((256, 256))
51
 
52
+
53
+ res = pipe.train(
54
+ prompt,
55
+ init_image,
56
+ guidance_scale=7.5,
57
+ num_inference_steps=50,
58
+ generator=generator,
59
+ text_embedding_optimization_steps=0,
60
+ model_fine_tuning_optimization_steps=0)
61
+
62
+ with torch.no_grad():
63
+ torch.cuda.empty_cache()
64
+
65
+
66
+
67
+ res = pipe(alpha=1)
68
+
69
+
70
+ return res.images[0]
71
+
72
  title = """
73
  <div style="text-align: center; max-width: 650px; margin: 0 auto;">
74
  <div
 
140
 
141
  prompt_input = gr.Textbox(label="Target text", placeholder="Describe the image with what you want to change about the subject")
142
  image_init = gr.Image(source="upload", type="filepath",label="Input Image")
143
+ with gr.Row():
144
+ trn_text = gr.Slider(100, 500, value=250, label="text embedding")
145
+ trn_steps = gr.Slider(250, 1000, value=500, label="finetuning steps")
146
+ with gr.Row():
147
+ train_btn = gr.Button("1.Train")
148
+ gen_btn = gr.Button("2.Generate")
149
+
150
+ training_status = gr.Textbox(label="training status")
151
  image_output = gr.Image(label="Edited image")
152
 
153
+ #examples=[['a sitting dog','imagic-dog.png', 250], ['a photo of a bird spreading wings','imagic-bird.png',250]]
154
+ #ex = gr.Examples(examples=examples, fn=infer, inputs=[prompt_input,image_init,trn_steps], outputs=[image_output], cache_examples=False, run_on_click=False)
155
 
156
 
157
  gr.HTML(article)
158
 
159
+ train_btn.click(fn=train, inputs=[prompt_input,image_init,trn_text,trn_steps], outputs=[training_status])
160
+ gen_btn.click(fn=generate, inputs=[prompt_input,image_init], outputs=[image_output])
161
 
162
  block.queue(max_size=12).launch(show_api=False)