Spaces:
Running
on
A10G
Running
on
A10G
Commit
•
45e73ca
1
Parent(s):
968ec9f
Update pipeline_semantic_stable_diffusion_img2img_solver.py (#8)
Browse files
- Update pipeline_semantic_stable_diffusion_img2img_solver.py (ec1516109c3c0bade7d64c6e18b300f19e391eab)
- Update app.py (21466b8ae643eaaed1223ac184b7680db840213c)
- Update pipeline_semantic_stable_diffusion_img2img_solver.py (996e7e071f43404cb60f344239e3bb45d2e93f4a)
Co-authored-by: Linoy Tsaban <LinoyTsaban@users.noreply.huggingface.co>
- app.py +18 -15
- pipeline_semantic_stable_diffusion_img2img_solver.py +8 -7
app.py
CHANGED
@@ -35,9 +35,9 @@ def caption_image(input_image):
|
|
35 |
generated_caption = blip_processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
|
36 |
return generated_caption, generated_caption
|
37 |
|
38 |
-
def sample(zs, wts, prompt_tar="", cfg_scale_tar=15, skip=36, eta=1):
|
39 |
latents = wts[-1].expand(1, -1, -1, -1)
|
40 |
-
img = pipe(
|
41 |
prompt=prompt_tar,
|
42 |
init_latents=latents,
|
43 |
guidance_scale=cfg_scale_tar,
|
@@ -45,9 +45,10 @@ def sample(zs, wts, prompt_tar="", cfg_scale_tar=15, skip=36, eta=1):
|
|
45 |
# num_inference_steps=steps,
|
46 |
# use_ddpm=True,
|
47 |
# wts=wts.value,
|
|
|
48 |
zs=zs,
|
49 |
-
)
|
50 |
-
return img
|
51 |
|
52 |
|
53 |
def reconstruct(
|
@@ -57,6 +58,7 @@ def reconstruct(
|
|
57 |
skip,
|
58 |
wts,
|
59 |
zs,
|
|
|
60 |
do_reconstruction,
|
61 |
reconstruction,
|
62 |
reconstruct_button,
|
@@ -77,8 +79,8 @@ def reconstruct(
|
|
77 |
): # if image caption was not changed, run actual reconstruction
|
78 |
tar_prompt = ""
|
79 |
latents = wts[-1].expand(1, -1, -1, -1)
|
80 |
-
reconstruction = sample(
|
81 |
-
zs, wts, prompt_tar=tar_prompt, skip=skip, cfg_scale_tar=tar_cfg_scale
|
82 |
)
|
83 |
do_reconstruction = False
|
84 |
return (
|
@@ -128,7 +130,7 @@ def load_and_invert(
|
|
128 |
## SEGA ##
|
129 |
|
130 |
def edit(input_image,
|
131 |
-
wts, zs,
|
132 |
tar_prompt,
|
133 |
image_caption,
|
134 |
steps,
|
@@ -195,27 +197,27 @@ def edit(input_image,
|
|
195 |
)
|
196 |
|
197 |
latnets = wts[-1].expand(1, -1, -1, -1)
|
198 |
-
sega_out = pipe(prompt=tar_prompt,
|
199 |
init_latents=latnets,
|
200 |
guidance_scale = tar_cfg_scale,
|
201 |
# num_images_per_prompt=1,
|
202 |
# num_inference_steps=steps,
|
203 |
# use_ddpm=True,
|
204 |
# wts=wts.value,
|
205 |
-
zs=zs, **editing_args)
|
206 |
|
207 |
-
return sega_out.images[0], gr.update(visible=True), do_reconstruction, reconstruction, wts, zs, do_inversion, show_share_button
|
208 |
|
209 |
|
210 |
else: # if sega concepts were not added, performs regular ddpm sampling
|
211 |
|
212 |
if do_reconstruction: # if ddpm sampling wasn't computed
|
213 |
-
pure_ddpm_img = sample(zs, wts, prompt_tar=tar_prompt, skip=skip, cfg_scale_tar=tar_cfg_scale)
|
214 |
reconstruction = pure_ddpm_img
|
215 |
do_reconstruction = False
|
216 |
-
return pure_ddpm_img, gr.update(visible=False), do_reconstruction, reconstruction, wts, zs, do_inversion, show_share_button
|
217 |
|
218 |
-
return reconstruction, gr.update(visible=False), do_reconstruction, reconstruction, wts, zs, do_inversion, show_share_button
|
219 |
|
220 |
|
221 |
def randomize_seed_fn(seed, is_random):
|
@@ -458,6 +460,7 @@ with gr.Blocks(css="style.css") as demo:
|
|
458 |
gr.HTML(intro)
|
459 |
wts = gr.State()
|
460 |
zs = gr.State()
|
|
|
461 |
reconstruction = gr.State()
|
462 |
do_inversion = gr.State(value=True)
|
463 |
do_reconstruction = gr.State(value=True)
|
@@ -693,7 +696,7 @@ with gr.Blocks(css="style.css") as demo:
|
|
693 |
run_button.click(
|
694 |
fn=edit,
|
695 |
inputs=[input_image,
|
696 |
-
wts, zs,
|
697 |
tar_prompt,
|
698 |
image_caption,
|
699 |
steps,
|
@@ -713,7 +716,7 @@ with gr.Blocks(css="style.css") as demo:
|
|
713 |
|
714 |
|
715 |
],
|
716 |
-
outputs=[sega_edited_image, reconstruct_button, do_reconstruction, reconstruction, wts, zs, do_inversion, share_btn_container])
|
717 |
# .success(fn=update_gallery_display, inputs= [prev_output_image, sega_edited_image], outputs = [gallery, gallery, prev_output_image])
|
718 |
|
719 |
|
|
|
35 |
generated_caption = blip_processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
|
36 |
return generated_caption, generated_caption
|
37 |
|
38 |
+
def sample(zs, wts, attention_store, prompt_tar="", cfg_scale_tar=15, skip=36, eta=1):
|
39 |
latents = wts[-1].expand(1, -1, -1, -1)
|
40 |
+
img, attention_store = pipe(
|
41 |
prompt=prompt_tar,
|
42 |
init_latents=latents,
|
43 |
guidance_scale=cfg_scale_tar,
|
|
|
45 |
# num_inference_steps=steps,
|
46 |
# use_ddpm=True,
|
47 |
# wts=wts.value,
|
48 |
+
attention_store = attention_store,
|
49 |
zs=zs,
|
50 |
+
)
|
51 |
+
return img.images[0], attention_store
|
52 |
|
53 |
|
54 |
def reconstruct(
|
|
|
58 |
skip,
|
59 |
wts,
|
60 |
zs,
|
61 |
+
attention_store,
|
62 |
do_reconstruction,
|
63 |
reconstruction,
|
64 |
reconstruct_button,
|
|
|
79 |
): # if image caption was not changed, run actual reconstruction
|
80 |
tar_prompt = ""
|
81 |
latents = wts[-1].expand(1, -1, -1, -1)
|
82 |
+
reconstruction, attention_store = sample(
|
83 |
+
zs, wts, attention_store=attention_store, prompt_tar=tar_prompt, skip=skip, cfg_scale_tar=tar_cfg_scale
|
84 |
)
|
85 |
do_reconstruction = False
|
86 |
return (
|
|
|
130 |
## SEGA ##
|
131 |
|
132 |
def edit(input_image,
|
133 |
+
wts, zs, attention_store,
|
134 |
tar_prompt,
|
135 |
image_caption,
|
136 |
steps,
|
|
|
197 |
)
|
198 |
|
199 |
latnets = wts[-1].expand(1, -1, -1, -1)
|
200 |
+
sega_out, attention_store = pipe(prompt=tar_prompt,
|
201 |
init_latents=latnets,
|
202 |
guidance_scale = tar_cfg_scale,
|
203 |
# num_images_per_prompt=1,
|
204 |
# num_inference_steps=steps,
|
205 |
# use_ddpm=True,
|
206 |
# wts=wts.value,
|
207 |
+
zs=zs, attention_store=attention_store, **editing_args)
|
208 |
|
209 |
+
return sega_out.images[0], gr.update(visible=True), do_reconstruction, reconstruction, wts, zs, attention_store, do_inversion, show_share_button
|
210 |
|
211 |
|
212 |
else: # if sega concepts were not added, performs regular ddpm sampling
|
213 |
|
214 |
if do_reconstruction: # if ddpm sampling wasn't computed
|
215 |
+
pure_ddpm_img, attention_store = sample(zs, wts, attention_store=attention_store, prompt_tar=tar_prompt, skip=skip, cfg_scale_tar=tar_cfg_scale)
|
216 |
reconstruction = pure_ddpm_img
|
217 |
do_reconstruction = False
|
218 |
+
return pure_ddpm_img, gr.update(visible=False), do_reconstruction, reconstruction, wts, zs, attention_store, do_inversion, show_share_button
|
219 |
|
220 |
+
return reconstruction, gr.update(visible=False), do_reconstruction, reconstruction, wts, zs, attention_store, do_inversion, show_share_button
|
221 |
|
222 |
|
223 |
def randomize_seed_fn(seed, is_random):
|
|
|
460 |
gr.HTML(intro)
|
461 |
wts = gr.State()
|
462 |
zs = gr.State()
|
463 |
+
attention_store=gr.State()
|
464 |
reconstruction = gr.State()
|
465 |
do_inversion = gr.State(value=True)
|
466 |
do_reconstruction = gr.State(value=True)
|
|
|
696 |
run_button.click(
|
697 |
fn=edit,
|
698 |
inputs=[input_image,
|
699 |
+
wts, zs, attention_store,
|
700 |
tar_prompt,
|
701 |
image_caption,
|
702 |
steps,
|
|
|
716 |
|
717 |
|
718 |
],
|
719 |
+
outputs=[sega_edited_image, reconstruct_button, do_reconstruction, reconstruction, wts, zs,attention_store, do_inversion, share_btn_container])
|
720 |
# .success(fn=update_gallery_display, inputs= [prev_output_image, sega_edited_image], outputs = [gallery, gallery, prev_output_image])
|
721 |
|
722 |
|
pipeline_semantic_stable_diffusion_img2img_solver.py
CHANGED
@@ -499,6 +499,7 @@ class SemanticStableDiffusionImg2ImgPipeline_DPMSolver(DiffusionPipeline):
|
|
499 |
verbose=True,
|
500 |
use_cross_attn_mask: bool = False,
|
501 |
# Attention store (just for visualization purposes)
|
|
|
502 |
attn_store_steps: Optional[List[int]] = [],
|
503 |
store_averaged_over_steps: bool = True,
|
504 |
use_intersect_mask: bool = False,
|
@@ -771,8 +772,8 @@ class SemanticStableDiffusionImg2ImgPipeline_DPMSolver(DiffusionPipeline):
|
|
771 |
timesteps = timesteps[-zs.shape[0]:]
|
772 |
|
773 |
if use_cross_attn_mask:
|
774 |
-
|
775 |
-
self.prepare_unet(
|
776 |
# 5. Prepare latent variables
|
777 |
num_channels_latents = self.unet.config.in_channels
|
778 |
latents = self.prepare_latents(
|
@@ -917,8 +918,8 @@ class SemanticStableDiffusionImg2ImgPipeline_DPMSolver(DiffusionPipeline):
|
|
917 |
noise_guidance_edit_tmp = noise_guidance_edit_tmp * user_mask
|
918 |
|
919 |
if use_cross_attn_mask:
|
920 |
-
out =
|
921 |
-
attention_maps=
|
922 |
prompts=self.text_cross_attention_maps,
|
923 |
res=16,
|
924 |
from_where=["up", "down"],
|
@@ -1080,7 +1081,7 @@ class SemanticStableDiffusionImg2ImgPipeline_DPMSolver(DiffusionPipeline):
|
|
1080 |
store_step = i in attn_store_steps
|
1081 |
if store_step:
|
1082 |
print(f"storing attention for step {i}")
|
1083 |
-
|
1084 |
|
1085 |
# call the callback, if provided
|
1086 |
if callback is not None and i % callback_steps == 0:
|
@@ -1102,9 +1103,9 @@ class SemanticStableDiffusionImg2ImgPipeline_DPMSolver(DiffusionPipeline):
|
|
1102 |
image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize)
|
1103 |
|
1104 |
if not return_dict:
|
1105 |
-
return (image, has_nsfw_concept)
|
1106 |
|
1107 |
-
return SemanticStableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
|
1108 |
|
1109 |
def encode_text(self, prompts):
|
1110 |
text_inputs = self.tokenizer(
|
|
|
499 |
verbose=True,
|
500 |
use_cross_attn_mask: bool = False,
|
501 |
# Attention store (just for visualization purposes)
|
502 |
+
attention_store = None,
|
503 |
attn_store_steps: Optional[List[int]] = [],
|
504 |
store_averaged_over_steps: bool = True,
|
505 |
use_intersect_mask: bool = False,
|
|
|
772 |
timesteps = timesteps[-zs.shape[0]:]
|
773 |
|
774 |
if use_cross_attn_mask:
|
775 |
+
attention_store = AttentionStore(average=store_averaged_over_steps, batch_size=batch_size)
|
776 |
+
self.prepare_unet(attention_store, PnP=False)
|
777 |
# 5. Prepare latent variables
|
778 |
num_channels_latents = self.unet.config.in_channels
|
779 |
latents = self.prepare_latents(
|
|
|
918 |
noise_guidance_edit_tmp = noise_guidance_edit_tmp * user_mask
|
919 |
|
920 |
if use_cross_attn_mask:
|
921 |
+
out = attention_store.aggregate_attention(
|
922 |
+
attention_maps=attention_store.step_store,
|
923 |
prompts=self.text_cross_attention_maps,
|
924 |
res=16,
|
925 |
from_where=["up", "down"],
|
|
|
1081 |
store_step = i in attn_store_steps
|
1082 |
if store_step:
|
1083 |
print(f"storing attention for step {i}")
|
1084 |
+
attention_store.between_steps(store_step)
|
1085 |
|
1086 |
# call the callback, if provided
|
1087 |
if callback is not None and i % callback_steps == 0:
|
|
|
1103 |
image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize)
|
1104 |
|
1105 |
if not return_dict:
|
1106 |
+
return (image, has_nsfw_concept), attention_store
|
1107 |
|
1108 |
+
return SemanticStableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept), attention_store
|
1109 |
|
1110 |
def encode_text(self, prompts):
|
1111 |
text_inputs = self.tokenizer(
|