Linoy Tsaban
commited on
Commit
•
6a5a59b
1
Parent(s):
af56f98
Update app.py
Browse files
app.py
CHANGED
@@ -14,12 +14,6 @@ import re
|
|
14 |
|
15 |
|
16 |
|
17 |
-
def randomize_seed_fn(seed, randomize_seed):
|
18 |
-
if randomize_seed:
|
19 |
-
seed = random.randint(0, np.iinfo(np.int32).max)
|
20 |
-
torch.manual_seed(seed)
|
21 |
-
return seed
|
22 |
-
|
23 |
|
24 |
def invert(x0, prompt_src="", num_diffusion_steps=100, cfg_scale_src = 3.5, eta = 1):
|
25 |
|
@@ -116,8 +110,29 @@ def get_example():
|
|
116 |
]]
|
117 |
return case
|
118 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
119 |
|
120 |
-
def
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
121 |
input_image,
|
122 |
do_inversion,
|
123 |
seed, randomize_seed,
|
@@ -127,7 +142,7 @@ def invert_and_reconstruct(
|
|
127 |
steps=100,
|
128 |
src_cfg_scale = 3.5,
|
129 |
skip=36,
|
130 |
-
tar_cfg_scale=15
|
131 |
|
132 |
):
|
133 |
|
@@ -140,10 +155,7 @@ def invert_and_reconstruct(
|
|
140 |
wts = gr.State(value=wts_tensor)
|
141 |
zs = gr.State(value=zs_tensor)
|
142 |
do_inversion = False
|
143 |
-
|
144 |
-
# output = sample(zs.value, wts.value, prompt_tar=tar_prompt, skip=skip, cfg_scale_tar=tar_cfg_scale)
|
145 |
-
|
146 |
-
# return output, wts, zs, do_inversion
|
147 |
return wts, zs, do_inversion
|
148 |
|
149 |
|
@@ -244,7 +256,10 @@ with gr.Blocks(css='style.css') as demo:
|
|
244 |
else:
|
245 |
return row2.update(visible=True), row3.update(visible=True), plus.update(visible=False), 3
|
246 |
|
247 |
-
|
|
|
|
|
|
|
248 |
def reset_do_inversion():
|
249 |
do_inversion = True
|
250 |
return do_inversion
|
@@ -255,15 +270,16 @@ with gr.Blocks(css='style.css') as demo:
|
|
255 |
zs = gr.State()
|
256 |
do_inversion = gr.State(value=True)
|
257 |
sega_concepts_counter = gr.State(1)
|
|
|
258 |
|
259 |
|
260 |
|
261 |
with gr.Row():
|
262 |
input_image = gr.Image(label="Input Image", interactive=True)
|
263 |
-
|
264 |
sega_edited_image = gr.Image(label=f"DDPM + SEGA Edited Image", interactive=False)
|
265 |
input_image.style(height=365, width=365)
|
266 |
-
|
267 |
sega_edited_image.style(height=365, width=365)
|
268 |
|
269 |
with gr.Tabs() as tabs:
|
@@ -322,12 +338,13 @@ with gr.Blocks(css='style.css') as demo:
|
|
322 |
)
|
323 |
|
324 |
with gr.Row().style(mobile_collapse=False, equal_height=True):
|
325 |
-
|
326 |
|
327 |
|
328 |
with gr.Row():
|
329 |
with gr.Column(scale=1, min_width=100):
|
330 |
run_button = gr.Button("Run")
|
|
|
331 |
# with gr.Column(scale=1, min_width=100):
|
332 |
# edit_button = gr.Button("Edit")
|
333 |
|
@@ -350,16 +367,25 @@ with gr.Blocks(css='style.css') as demo:
|
|
350 |
|
351 |
|
352 |
|
353 |
-
|
354 |
outputs= [row2, row3, plus, sega_concepts_counter], queue = False)
|
355 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
356 |
|
357 |
run_button.click(
|
358 |
fn = randomize_seed_fn,
|
359 |
inputs = [seed, randomize_seed],
|
360 |
outputs = [seed],
|
361 |
queue = False).then(
|
362 |
-
fn=
|
363 |
inputs=[input_image,
|
364 |
do_inversion,
|
365 |
seed, randomize_seed,
|
@@ -369,10 +395,10 @@ with gr.Blocks(css='style.css') as demo:
|
|
369 |
steps,
|
370 |
src_cfg_scale,
|
371 |
skip,
|
372 |
-
tar_cfg_scale
|
373 |
],
|
374 |
-
# outputs=[ddpm_edited_image, wts, zs, do_inversion],
|
375 |
outputs=[wts, zs, do_inversion],
|
|
|
376 |
).success(
|
377 |
fn=edit,
|
378 |
inputs=[input_image,
|
@@ -389,8 +415,17 @@ with gr.Blocks(css='style.css') as demo:
|
|
389 |
|
390 |
],
|
391 |
outputs=[sega_edited_image],
|
|
|
|
|
|
|
392 |
)
|
393 |
|
|
|
|
|
|
|
|
|
|
|
|
|
394 |
# Automatically start inverting upon input_image change
|
395 |
input_image.change(
|
396 |
fn = reset_do_inversion,
|
|
|
14 |
|
15 |
|
16 |
|
|
|
|
|
|
|
|
|
|
|
|
|
17 |
|
18 |
def invert(x0, prompt_src="", num_diffusion_steps=100, cfg_scale_src = 3.5, eta = 1):
|
19 |
|
|
|
110 |
]]
|
111 |
return case
|
112 |
|
113 |
+
def randomize_seed_fn(seed, randomize_seed):
|
114 |
+
if randomize_seed:
|
115 |
+
seed = random.randint(0, np.iinfo(np.int32).max)
|
116 |
+
torch.manual_seed(seed)
|
117 |
+
return seed
|
118 |
+
|
119 |
+
|
120 |
+
|
121 |
|
122 |
+
def reconstruct(tar_prompt,
|
123 |
+
tar_cfg_scale,
|
124 |
+
skip,
|
125 |
+
wts, zs,
|
126 |
+
# do_reconstruction,
|
127 |
+
# reconstruction
|
128 |
+
)
|
129 |
+
|
130 |
+
):
|
131 |
+
# if do_reconstruction:
|
132 |
+
reconstruction = sample(zs.value, wts.value, prompt_tar=tar_prompt, skip=skip, cfg_scale_tar=tar_cfg_scale)
|
133 |
+
return reconstruction
|
134 |
+
|
135 |
+
def load_and_invert(
|
136 |
input_image,
|
137 |
do_inversion,
|
138 |
seed, randomize_seed,
|
|
|
142 |
steps=100,
|
143 |
src_cfg_scale = 3.5,
|
144 |
skip=36,
|
145 |
+
tar_cfg_scale=15
|
146 |
|
147 |
):
|
148 |
|
|
|
155 |
wts = gr.State(value=wts_tensor)
|
156 |
zs = gr.State(value=zs_tensor)
|
157 |
do_inversion = False
|
158 |
+
|
|
|
|
|
|
|
159 |
return wts, zs, do_inversion
|
160 |
|
161 |
|
|
|
256 |
else:
|
257 |
return row2.update(visible=True), row3.update(visible=True), plus.update(visible=False), 3
|
258 |
|
259 |
+
def show_reconstruction_option():
|
260 |
+
return reconstruct_button.update(visible=True)
|
261 |
+
|
262 |
+
|
263 |
def reset_do_inversion():
|
264 |
do_inversion = True
|
265 |
return do_inversion
|
|
|
270 |
zs = gr.State()
|
271 |
do_inversion = gr.State(value=True)
|
272 |
sega_concepts_counter = gr.State(1)
|
273 |
+
# reconstruction = gr.State()
|
274 |
|
275 |
|
276 |
|
277 |
with gr.Row():
|
278 |
input_image = gr.Image(label="Input Image", interactive=True)
|
279 |
+
ddpm_edited_image = gr.Image(label=f"DDPM Reconstructed Image", interactive=False, visible=False)
|
280 |
sega_edited_image = gr.Image(label=f"DDPM + SEGA Edited Image", interactive=False)
|
281 |
input_image.style(height=365, width=365)
|
282 |
+
ddpm_edited_image.style(height=512, width=512)
|
283 |
sega_edited_image.style(height=365, width=365)
|
284 |
|
285 |
with gr.Tabs() as tabs:
|
|
|
338 |
)
|
339 |
|
340 |
with gr.Row().style(mobile_collapse=False, equal_height=True):
|
341 |
+
add_concept_button = gr.Button("+")
|
342 |
|
343 |
|
344 |
with gr.Row():
|
345 |
with gr.Column(scale=1, min_width=100):
|
346 |
run_button = gr.Button("Run")
|
347 |
+
reconstruct_button = gr.Button("Show me the reconstruction")
|
348 |
# with gr.Column(scale=1, min_width=100):
|
349 |
# edit_button = gr.Button("Edit")
|
350 |
|
|
|
367 |
|
368 |
|
369 |
|
370 |
+
add_concept_button.click(fn = add_concept, inputs=sega_concepts_counter,
|
371 |
outputs= [row2, row3, plus, sega_concepts_counter], queue = False)
|
372 |
|
373 |
+
reconstruct_button.click(
|
374 |
+
fn = reconstruct,
|
375 |
+
inputs = [tar_prompt,
|
376 |
+
tar_cfg_scale,
|
377 |
+
skip,
|
378 |
+
wts, zs]
|
379 |
+
outputs = [ddpm_edited_image]
|
380 |
+
)
|
381 |
+
|
382 |
|
383 |
run_button.click(
|
384 |
fn = randomize_seed_fn,
|
385 |
inputs = [seed, randomize_seed],
|
386 |
outputs = [seed],
|
387 |
queue = False).then(
|
388 |
+
fn=load_and_invert,
|
389 |
inputs=[input_image,
|
390 |
do_inversion,
|
391 |
seed, randomize_seed,
|
|
|
395 |
steps,
|
396 |
src_cfg_scale,
|
397 |
skip,
|
398 |
+
tar_cfg_scale
|
399 |
],
|
|
|
400 |
outputs=[wts, zs, do_inversion],
|
401 |
+
|
402 |
).success(
|
403 |
fn=edit,
|
404 |
inputs=[input_image,
|
|
|
415 |
|
416 |
],
|
417 |
outputs=[sega_edited_image],
|
418 |
+
).success(
|
419 |
+
fn = show_reconstruction_option,
|
420 |
+
outputs = [reconstruct_button]
|
421 |
)
|
422 |
|
423 |
+
reconstruct_button.click(
|
424 |
+
fn =
|
425 |
+
)
|
426 |
+
|
427 |
+
|
428 |
+
|
429 |
# Automatically start inverting upon input_image change
|
430 |
input_image.change(
|
431 |
fn = reset_do_inversion,
|