Spaces:
Running
on
A10G
Running
on
A10G
Linoy Tsaban
commited on
Commit
•
76afba1
1
Parent(s):
c71b83b
Update app.py
Browse files
app.py
CHANGED
@@ -34,61 +34,76 @@ def caption_image(input_image):
|
|
34 |
generated_caption = blip_processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
|
35 |
return generated_caption, generated_caption
|
36 |
|
37 |
-
def sample(zs, wts, prompt_tar="", cfg_scale_tar=15, skip=36, eta
|
38 |
-
|
39 |
-
|
40 |
-
|
41 |
-
|
42 |
-
|
43 |
-
|
44 |
-
|
45 |
-
|
46 |
-
|
47 |
-
|
|
|
48 |
return img
|
49 |
|
50 |
-
def reconstruct(tar_prompt,
|
51 |
-
image_caption,
|
52 |
-
tar_cfg_scale,
|
53 |
-
skip,
|
54 |
-
wts, zs,
|
55 |
-
do_reconstruction,
|
56 |
-
reconstruction,
|
57 |
-
reconstruct_button
|
58 |
-
):
|
59 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
60 |
if reconstruct_button == "Hide Reconstruction":
|
61 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
62 |
|
63 |
else:
|
64 |
-
|
65 |
-
|
66 |
-
|
67 |
-
|
68 |
-
|
69 |
-
|
70 |
-
|
71 |
-
|
72 |
-
|
73 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
74 |
|
75 |
|
76 |
def load_and_invert(
|
77 |
-
|
78 |
-
|
79 |
-
|
80 |
-
|
81 |
-
|
82 |
-
|
83 |
-
|
84 |
-
|
85 |
-
|
86 |
-
|
87 |
-
|
88 |
-
|
|
|
89 |
):
|
90 |
-
|
91 |
-
|
92 |
# x0 = load_512(input_image, device=device).to(torch.float16)
|
93 |
|
94 |
if do_inversion or randomize_seed:
|
@@ -96,16 +111,14 @@ def load_and_invert(
|
|
96 |
seed = randomize_seed_fn()
|
97 |
seed_everything(seed)
|
98 |
# invert and retrieve noise maps and latent
|
99 |
-
|
100 |
-
|
101 |
-
|
102 |
-
|
103 |
-
|
104 |
-
|
105 |
-
|
106 |
-
|
107 |
-
wts = gr.State(value=wts_tensor)
|
108 |
-
zs = gr.State(value=zs_tensor)
|
109 |
do_inversion = False
|
110 |
|
111 |
return wts, zs, do_inversion, inversion_progress.update(visible=False)
|
@@ -171,6 +184,8 @@ def edit(input_image,
|
|
171 |
edit_warmup_steps=[warmup_1, warmup_2, warmup_3,],
|
172 |
edit_guidance_scale=[guidnace_scale_1,guidnace_scale_2,guidnace_scale_3],
|
173 |
edit_threshold=[threshold_1, threshold_2, threshold_3],
|
|
|
|
|
174 |
eta=1,
|
175 |
use_cross_attn_mask=use_cross_attn_mask,
|
176 |
use_intersect_mask=use_intersect_mask
|
|
|
34 |
generated_caption = blip_processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
|
35 |
return generated_caption, generated_caption
|
36 |
|
37 |
+
def sample(zs, wts, prompt_tar="", cfg_scale_tar=15, skip=36, eta=1):
|
38 |
+
latents = wts[-1].expand(1, -1, -1, -1)
|
39 |
+
img = pipe(
|
40 |
+
prompt=prompt_tar,
|
41 |
+
init_latents=latents,
|
42 |
+
guidance_scale=cfg_scale_tar,
|
43 |
+
# num_images_per_prompt=1,
|
44 |
+
# num_inference_steps=steps,
|
45 |
+
# use_ddpm=True,
|
46 |
+
# wts=wts.value,
|
47 |
+
zs=zs,
|
48 |
+
).images[0]
|
49 |
return img
|
50 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
51 |
|
52 |
+
def reconstruct(
|
53 |
+
tar_prompt,
|
54 |
+
image_caption,
|
55 |
+
tar_cfg_scale,
|
56 |
+
skip,
|
57 |
+
wts,
|
58 |
+
zs,
|
59 |
+
do_reconstruction,
|
60 |
+
reconstruction,
|
61 |
+
reconstruct_button,
|
62 |
+
):
|
63 |
if reconstruct_button == "Hide Reconstruction":
|
64 |
+
return (
|
65 |
+
reconstruction,
|
66 |
+
reconstruction,
|
67 |
+
ddpm_edited_image.update(visible=False),
|
68 |
+
do_reconstruction,
|
69 |
+
"Show Reconstruction",
|
70 |
+
)
|
71 |
|
72 |
else:
|
73 |
+
if do_reconstruction:
|
74 |
+
if (
|
75 |
+
image_caption.lower() == tar_prompt.lower()
|
76 |
+
): # if image caption was not changed, run actual reconstruction
|
77 |
+
tar_prompt = ""
|
78 |
+
latents = wts[-1].expand(1, -1, -1, -1)
|
79 |
+
reconstruction = sample(
|
80 |
+
zs, wts, prompt_tar=tar_prompt, skip=skip, cfg_scale_tar=tar_cfg_scale
|
81 |
+
)
|
82 |
+
do_reconstruction = False
|
83 |
+
return (
|
84 |
+
reconstruction,
|
85 |
+
reconstruction,
|
86 |
+
ddpm_edited_image.update(visible=True),
|
87 |
+
do_reconstruction,
|
88 |
+
"Hide Reconstruction",
|
89 |
+
)
|
90 |
|
91 |
|
92 |
def load_and_invert(
|
93 |
+
input_image,
|
94 |
+
do_inversion,
|
95 |
+
seed,
|
96 |
+
randomize_seed,
|
97 |
+
wts,
|
98 |
+
zs,
|
99 |
+
src_prompt="",
|
100 |
+
# tar_prompt="",
|
101 |
+
steps=30,
|
102 |
+
src_cfg_scale=3.5,
|
103 |
+
skip=15,
|
104 |
+
tar_cfg_scale=15,
|
105 |
+
progress=gr.Progress(track_tqdm=True),
|
106 |
):
|
|
|
|
|
107 |
# x0 = load_512(input_image, device=device).to(torch.float16)
|
108 |
|
109 |
if do_inversion or randomize_seed:
|
|
|
111 |
seed = randomize_seed_fn()
|
112 |
seed_everything(seed)
|
113 |
# invert and retrieve noise maps and latent
|
114 |
+
zs, wts = pipe.invert(
|
115 |
+
image_path=input_image,
|
116 |
+
source_prompt=src_prompt,
|
117 |
+
source_guidance_scale=src_cfg_scale,
|
118 |
+
num_inversion_steps=steps,
|
119 |
+
skip=skip,
|
120 |
+
eta=1.0,
|
121 |
+
)
|
|
|
|
|
122 |
do_inversion = False
|
123 |
|
124 |
return wts, zs, do_inversion, inversion_progress.update(visible=False)
|
|
|
184 |
edit_warmup_steps=[warmup_1, warmup_2, warmup_3,],
|
185 |
edit_guidance_scale=[guidnace_scale_1,guidnace_scale_2,guidnace_scale_3],
|
186 |
edit_threshold=[threshold_1, threshold_2, threshold_3],
|
187 |
+
edit_momentum_scale=0,
|
188 |
+
edit_mom_beta=0.6,
|
189 |
eta=1,
|
190 |
use_cross_attn_mask=use_cross_attn_mask,
|
191 |
use_intersect_mask=use_intersect_mask
|