# leditsplusplus / app.py
# Source revision 345d7b4 (last commit by multimodalart):
#   "Update pipeline_semantic_stable_diffusion_img2img_solver.py (#9)"
import gradio as gr
import torch
import numpy as np
import requests
import random
from io import BytesIO
from utils import *
from constants import *
from pipeline_semantic_stable_diffusion_img2img_solver import SemanticStableDiffusionImg2ImgPipeline_DPMSolver
from torch import autocast, inference_mode
from diffusers import StableDiffusionPipeline, AutoencoderKL
from diffusers.schedulers import DDIMScheduler
from scheduling_dpmsolver_multistep_inject import DPMSolverMultistepSchedulerInject
from transformers import AutoProcessor, BlipForConditionalGeneration
from share_btn import community_icon_html, loading_icon_html, share_js
# ---------------------------------------------------------------------------
# Load pipelines
# ---------------------------------------------------------------------------
# Alternative base model that was used previously: "runwayml/stable-diffusion-v1-5"
sd_model_id = "stabilityai/stable-diffusion-2-1-base"
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# The VAE from "stabilityai/sd-vae-ft-mse" replaces the pipeline's default one.
vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse", torch_dtype=torch.float16)
pipe = SemanticStableDiffusionImg2ImgPipeline_DPMSolver.from_pretrained(
    sd_model_id,
    vae=vae,
    torch_dtype=torch.float16,
    safety_checker=None,
    requires_safety_checker=False,
).to(device)
# Swap in the project's noise-injecting DPM-Solver scheduler (SDE DPM-Solver++, 2nd order).
pipe.scheduler = DPMSolverMultistepSchedulerInject.from_pretrained(
    sd_model_id,
    subfolder="scheduler",
    algorithm_type="sde-dpmsolver++",
    solver_order=2,
)

# BLIP captioner used to pre-fill the target prompt after an upload.
blip_processor = AutoProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
blip_model = BlipForConditionalGeneration.from_pretrained(
    "Salesforce/blip-image-captioning-base",
    torch_dtype=torch.float16,
).to(device)
## IMAGE CAPTIONING ##
def caption_image(input_image):
    """Caption *input_image* with the BLIP model.

    Generation is capped at ``max_length=50`` tokens. The caption is returned
    twice because the upload chain writes it both to the target-prompt
    textbox and to the ``image_caption`` state.
    """
    batch = blip_processor(images=input_image, return_tensors="pt").to(device, torch.float16)
    token_ids = blip_model.generate(pixel_values=batch.pixel_values, max_length=50)
    caption = blip_processor.batch_decode(token_ids, skip_special_tokens=True)[0]
    return caption, caption
def sample(zs, wts, attention_store, text_cross_attention_maps, prompt_tar="", cfg_scale_tar=15, skip=36, eta=1):
    """Denoise from the last inverted latent using the stored noise maps.

    Args:
        zs: noise maps produced by ``pipe.invert``.
        wts: latents produced by ``pipe.invert``; only ``wts[-1]`` is used.
        attention_store, text_cross_attention_maps: attention state handed to
            the pipeline; the pipeline's returned versions are passed back.
        prompt_tar: target prompt.
        cfg_scale_tar: guidance scale for the target prompt.
        skip, eta: accepted for call-site compatibility; not used here.

    Returns:
        ``(image, attention_store, text_cross_attention_maps)`` where *image*
        is the first image of the pipeline output.
    """
    start_latent = wts[-1].expand(1, -1, -1, -1)
    result, attention_store, text_cross_attention_maps = pipe(
        prompt=prompt_tar,
        init_latents=start_latent,
        guidance_scale=cfg_scale_tar,
        attention_store=attention_store,
        text_cross_attention_maps=text_cross_attention_maps,
        zs=zs,
    )
    first_image = result.images[0]
    return first_image, attention_store, text_cross_attention_maps
def reconstruct(
    tar_prompt,
    image_caption,
    tar_cfg_scale,
    skip,
    wts,
    zs,
    attention_store,
    text_cross_attention_maps,
    do_reconstruction,
    reconstruction,
    reconstruct_button,
):
    """Toggle the plain (concept-free) reconstruction of the inverted image.

    The button label drives the toggle: "Hide Reconstruction" hides the image,
    anything else shows it. The reconstruction is only recomputed when
    ``do_reconstruction`` is set; otherwise the cached ``reconstruction`` is
    reused.

    Returns:
        5 values for ``[ddpm_edited_image, reconstruction, ddpm_edited_image,
        do_reconstruction, reconstruct_button]``: the image, the cached image,
        a visibility update, the updated flag and the new button label.

    Raises:
        gr.Error: if a reconstruction is needed but no inversion exists yet.
    """
    if reconstruct_button == "Hide Reconstruction":
        return (
            reconstruction,
            reconstruction,
            gr.update(visible=False),
            do_reconstruction,
            "Show Reconstruction",
        )

    if do_reconstruction:
        if wts is None:
            raise gr.Error("Please upload an image first.")
        # An unchanged image caption means "no target prompt": reconstruct the input itself.
        if image_caption.lower() == tar_prompt.lower():
            tar_prompt = ""
        reconstruction, attention_store, text_cross_attention_maps = sample(
            zs,
            wts,
            attention_store=attention_store,
            text_cross_attention_maps=text_cross_attention_maps,
            prompt_tar=tar_prompt,
            skip=skip,
            cfg_scale_tar=tar_cfg_scale,
        )
        do_reconstruction = False

    return (
        reconstruction,
        reconstruction,
        gr.update(visible=True),
        do_reconstruction,
        "Hide Reconstruction",
    )
def load_and_invert(
    input_image,
    do_inversion,
    seed,
    randomize_seed,
    wts,
    zs,
    src_prompt="",
    steps=30,
    src_cfg_scale=3.5,
    skip=15,
    tar_cfg_scale=15,
    progress=gr.Progress(track_tqdm=True),
):
    """Invert *input_image* into latents and noise maps when needed.

    Inversion runs when ``do_inversion`` is set or ``randomize_seed`` is on.
    In that case the seed is (optionally) re-randomized and applied first.
    Otherwise the cached ``wts``/``zs`` are returned untouched.
    ``tar_cfg_scale`` and ``progress`` are not read in the body.

    Returns:
        ``(wts, zs, do_inversion, update)`` for
        ``[wts, zs, do_inversion, inversion_progress]``; the update hides the
        inversion-progress box.
    """
    if not (do_inversion or randomize_seed):
        return wts, zs, do_inversion, gr.update(visible=False)

    seed = randomize_seed_fn(seed, randomize_seed)
    seed_everything(seed)
    # Invert and retrieve the noise maps (zs) and latents (wts).
    zs, wts = pipe.invert(
        image_path=input_image,
        source_prompt=src_prompt,
        source_guidance_scale=src_cfg_scale,
        num_inversion_steps=steps,
        skip=skip,
        eta=1.0,
    )
    return wts, zs, False, gr.update(visible=False)
## SEGA ##
# "Mask type" radio choice -> (use_cross_attn_mask, use_intersect_mask).
_MASK_TYPE_FLAGS = {
    "No mask": (False, False),
    "Cross Attention Mask": (True, False),
    "Intersect Mask": (False, True),
}


def edit(input_image,
         wts, zs, attention_store, text_cross_attention_maps,
         tar_prompt,
         image_caption,
         steps,
         skip,
         tar_cfg_scale,
         edit_concept_1, edit_concept_2, edit_concept_3,
         guidnace_scale_1, guidnace_scale_2, guidnace_scale_3,
         warmup_1, warmup_2, warmup_3,
         neg_guidance_1, neg_guidance_2, neg_guidance_3,
         threshold_1, threshold_2, threshold_3,
         do_reconstruction,
         reconstruction,
         # for inversion in case it needs to be re-computed (and avoid delay):
         do_inversion,
         seed,
         randomize_seed,
         src_prompt,
         src_cfg_scale,
         mask_type,
         progress=gr.Progress(track_tqdm=True)):
    """Edit the uploaded image with up to three SEGA concepts.

    Flow:
      1. Re-invert the image if ``do_inversion`` is set or the seed is
         randomized. The chosen seed is applied right before every inversion.
      2. If at least one concept is non-empty, run the pipeline with the
         per-concept editing arguments.
      3. Otherwise return the plain reconstruction, computing it only when
         ``do_reconstruction`` is set and reusing the cached one otherwise.

    An unchanged BLIP caption in ``tar_prompt`` is treated as "no target prompt".

    Returns:
        10 values for ``[sega_edited_image, reconstruct_button,
        do_reconstruction, reconstruction, wts, zs, attention_store,
        text_cross_attention_maps, do_inversion, share_btn_container]``.

    Raises:
        gr.Error: if no image has been uploaded.
        ValueError: if ``mask_type`` is not one of the radio choices.
    """
    if input_image is None:
        raise gr.Error("Please upload an image first.")
    try:
        use_cross_attn_mask, use_intersect_mask = _MASK_TYPE_FLAGS[mask_type]
    except KeyError:
        raise ValueError(f"Unknown mask type: {mask_type!r}") from None

    show_share_button = gr.update(visible=True)

    if randomize_seed:
        seed = randomize_seed_fn(seed, randomize_seed)
    if do_inversion or randomize_seed:
        # Seed every inversion so a user-chosen seed is honoured too, not only
        # a randomized one.
        seed_everything(seed)
        zs, wts = pipe.invert(
            image_path=input_image,
            source_prompt=src_prompt,
            source_guidance_scale=src_cfg_scale,
            num_inversion_steps=steps,
            skip=skip,
            eta=1.0,
        )
        do_inversion = False

    if image_caption.lower() == tar_prompt.lower():  # caption not changed -> run pure SEGA
        tar_prompt = ""

    concepts = [edit_concept_1, edit_concept_2, edit_concept_3]
    if any(concepts):
        editing_args = dict(
            editing_prompt=concepts,
            reverse_editing_direction=[neg_guidance_1, neg_guidance_2, neg_guidance_3],
            edit_warmup_steps=[warmup_1, warmup_2, warmup_3],
            edit_guidance_scale=[guidnace_scale_1, guidnace_scale_2, guidnace_scale_3],
            edit_threshold=[threshold_1, threshold_2, threshold_3],
            edit_momentum_scale=0,
            edit_mom_beta=0,
            eta=1,
            use_cross_attn_mask=use_cross_attn_mask,
            use_intersect_mask=use_intersect_mask,
        )
        latents = wts[-1].expand(1, -1, -1, -1)
        sega_out, attention_store, text_cross_attention_maps = pipe(
            prompt=tar_prompt,
            init_latents=latents,
            guidance_scale=tar_cfg_scale,
            zs=zs,
            attention_store=attention_store,
            text_cross_attention_maps=text_cross_attention_maps,
            **editing_args,
        )
        return (sega_out.images[0], gr.update(visible=True), do_reconstruction, reconstruction,
                wts, zs, attention_store, text_cross_attention_maps, do_inversion, show_share_button)

    # No SEGA concepts: plain sampling, cached in `reconstruction` across calls.
    if do_reconstruction:
        reconstruction, attention_store, text_cross_attention_maps = sample(
            zs, wts,
            attention_store=attention_store,
            text_cross_attention_maps=text_cross_attention_maps,
            prompt_tar=tar_prompt,
            skip=skip,
            cfg_scale_tar=tar_cfg_scale,
        )
        do_reconstruction = False
    return (reconstruction, gr.update(visible=False), do_reconstruction, reconstruction,
            wts, zs, attention_store, text_cross_attention_maps, do_inversion, show_share_button)
def randomize_seed_fn(seed, is_random):
    """Return a fresh random seed in ``[0, int32 max]`` if *is_random*, else *seed*."""
    if not is_random:
        return seed
    return random.randint(0, np.iinfo(np.int32).max)
def seed_everything(seed):
    """Seed the torch (CPU and CUDA), Python ``random`` and NumPy RNGs with *seed*."""
    for seeder in (torch.manual_seed, torch.cuda.manual_seed, random.seed, np.random.seed):
        seeder(seed)
def crop_image(image):
    """Center-crop an image array to a square and resize it to 512x512.

    Args:
        image: an ``H x W`` or ``H x W x C`` array (presumably what ``gr.Image``
            hands over on upload — confirm), or ``None``.

    Returns:
        The cropped and resized array, or ``None`` when no image was given.
    """
    if image is None:
        return None
    # Only the first two axes matter, so grayscale (2-D) arrays work too.
    h, w = image.shape[:2]
    side = min(h, w)
    top = (h - side) // 2
    left = (w - side) // 2
    image = image[top:top + side, left:left + side]
    return np.array(Image.fromarray(image).resize((512, 512)))
def get_example():
    """Return the example rows fed to ``gr.Examples``.

    Each row is, in order: input image path, the three edit concepts, the
    pre-rendered output image path, three concept guidance scales, three
    warmup values, three "remove concept" flags, diffusion steps, skip steps,
    target guidance scale, three thresholds and the seed.
    """
    def row(image, concepts, output, scales, warmups, removes, steps, skip, cfg, thresholds, seed):
        # Flatten the grouped per-concept settings into the positional layout
        # expected by the gr.Examples `inputs` list.
        return [image, *concepts, output, *scales, *warmups, *removes,
                steps, skip, cfg, *thresholds, seed]

    return [
        row('examples/car_input.png',
            ('cherry blossom', 'green cabriolet', 'yellow car'),
            'examples/car_output.png',
            (13, 11, 7), (2, 2, 2), (False, False, True),
            50, 25, 7.5,
            (0.65, 0.8, 0.8),
            890000000),
        row('examples/girl_with_pearl_earring_input.png',
            ('glasses', '', ''),
            'examples/girl_with_pearl_earring_output.png',
            (4, 7, 0), (3, 2, 2), (False, False, False),
            50, 25, 5,
            (0.97, 0.95, 0.95),
            1900000000),
        row('examples/flower_field_input.jpg',
            ('pink tulips', 'red flowers', 'van gogh painting'),
            'examples/flower_field_output.png',
            (20, 7, 10), (1, 1, 1), (False, True, False),
            50, 25, 7,
            (0.9, 0.9, 0.8),
            1900000000),
    ]
def swap_visibilities(input_image,
                      edit_concept_1,
                      edit_concept_2,
                      edit_concept_3,
                      sega_edited_image,
                      guidnace_scale_1,
                      guidnace_scale_2,
                      guidnace_scale_3,
                      warmup_1,
                      warmup_2,
                      warmup_3,
                      neg_guidance_1,
                      neg_guidance_2,
                      neg_guidance_3,
                      steps,
                      skip,
                      tar_cfg_scale,
                      threshold_1,
                      threshold_2,
                      threshold_3,
                      sega_concepts_counter
                      ):
    """``gr.Examples`` callback: show the concept boxes for a loaded example.

    Only concepts 1 and 2 are laid out; the incoming counter is ignored and
    counting restarts from 0. Returns 14 values: the share-button container
    update, concept 1's box/button/slider/flag/row1/row2 updates, and concept
    2's seven updates (ending with the new counter).
    """
    counter = 0
    first = update_display_concept("Remove" if neg_guidance_1 else "Add",
                                   edit_concept_1, neg_guidance_1, counter)
    if edit_concept_2 == "":
        # No second concept: keep box 2 hidden and leave row 2 open for input.
        second = (gr.update(visible=False),
                  gr.update(visible=False),
                  gr.update(visible=False),
                  gr.update(value=neg_guidance_2),
                  gr.update(visible=True),
                  gr.update(visible=False),
                  counter + 1)
    else:
        second = update_display_concept("Remove" if neg_guidance_2 else "Add",
                                        edit_concept_2, neg_guidance_2, counter + 1)
    # first[-1] is concept 1's counter; concept 2's counter supersedes it.
    return (gr.update(visible=True), *first[:-1], *second)
########
# demo #
########
# HTML banner rendered at the top of the demo via gr.HTML(intro): title,
# project-page / paper links and a "Duplicate Space" badge.
intro = """
<h1 style="font-weight: 1400; text-align: center; margin-bottom: 7px;">
LEDITS++: Limitless Image Editing using Text-to-Image Models
</h1>
<p style="font-size: 0.95rem;margin: 0rem;line-height: 1.2em;margin-top:1em;display: inline-block">
<a href="https://leditsplusplus-project.static.hf.space" target="_blank">project page</a> | <a href="https://arxiv.org/abs/2311.16711" target="_blank">paper</a>
|
<a href="https://huggingface.co/spaces/leditsplusplus/demo?duplicate=true" target="_blank" style="
display: inline-block;
">
<img style="margin-top: -1em;margin-bottom: 0em;position: absolute;" src="https://bit.ly/3CWLGkA" alt="Duplicate Space"></a>
</p>
"""
with gr.Blocks(css="style.css") as demo:
    # NOTE(review): this file reached review with its indentation stripped; the
    # nesting of the layout contexts below is a reconstruction — confirm it
    # against the deployed Space.

    # ------------------------------------------------------------------
    # Event helpers. A `with` block does not open a new scope, so these are
    # module-level functions (swap_visibilities calls update_display_concept).
    # ------------------------------------------------------------------
    def update_counter(sega_concepts_counter, concept1, concept2, concept3):
        """Return the concept counter, recounting non-empty concepts only if it is ""."""
        if sega_concepts_counter == "":
            sega_concepts_counter = sum(1 for concept in (concept1, concept2, concept3) if concept != '')
        return sega_concepts_counter

    def remove_concept(sega_concepts_counter, row_triggered):
        """Reset one concept slot after its "x" button was pressed.

        Args:
            sega_concepts_counter: number of concepts currently added.
            row_triggered: 1-based index of the slot being removed.

        Returns:
            14 values matching the remove_concept{N} outputs: the slot's box,
            concept button, textbox, guidance slider, remove flag, warmup,
            threshold, add button and dropdown resets; visibility of
            row1..row4 (exactly one is shown); the decremented counter.
        """
        sega_concepts_counter -= 1
        rows_visibility = [gr.update(visible=False) for _ in range(4)]
        # Re-open the removed slot's input row, unless its index lies beyond
        # the remaining count; in that case open row number (counter + 1).
        if row_triggered - 1 > sega_concepts_counter:
            rows_visibility[sega_concepts_counter] = gr.update(visible=True)
        else:
            rows_visibility[row_triggered - 1] = gr.update(visible=True)
        return (gr.update(visible=False),
                gr.update(visible=False, value=""),
                gr.update(interactive=True, value=""),
                gr.update(visible=False, label="Concept Guidance Scale"),
                gr.update(interactive=True, value=False),
                gr.update(value=DEFAULT_WARMUP_STEPS),
                gr.update(value=DEFAULT_THRESHOLD),
                gr.update(visible=True),
                gr.update(interactive=True, value="custom"),
                *rows_visibility,
                sega_concepts_counter)

    def update_display_concept(button_label, edit_concept, neg_guidance, sega_concepts_counter):
        """Move a typed concept from its input row into its display box.

        ``button_label`` is the clicked button's value; "Remove" marks the
        concept as negative guidance and relabels its slider.

        Returns:
            7 values: box, concept button, guidance slider, remove flag, this
            input row (hidden), the next input row (shown), incremented counter.
        """
        sega_concepts_counter += 1
        guidance_scale_label = "Concept Guidance Scale"
        if button_label == 'Remove':
            neg_guidance = True
            guidance_scale_label = "Negative Guidance Scale"
        return (gr.update(visible=True),  # box_n
                gr.update(visible=True, value=edit_concept),  # concept_n
                gr.update(visible=True, label=guidance_scale_label),  # guidance_scale_n
                gr.update(value=neg_guidance),  # neg_guidance_n
                gr.update(visible=False),  # row_n
                gr.update(visible=True),  # row_n+1
                sega_concepts_counter)

    def display_editing_options(run_button, clear_button, sega_tab):
        """Return three "visible" updates (inputs ignored; not wired to any event in this file)."""
        return gr.update(visible=True), gr.update(visible=True), gr.update(visible=True)

    def update_interactive_mode(add_button_label):
        """Disable two components while the label reads "Clear", enable them otherwise."""
        interactive = add_button_label != "Clear"
        return gr.update(interactive=interactive), gr.update(interactive=interactive)

    def update_dropdown_parms(dropdown):
        """Return (guidance scale, warmup, threshold) presets for an edit type, or None."""
        presets = {
            'custom': (DEFAULT_SEGA_CONCEPT_GUIDANCE_SCALE, DEFAULT_WARMUP_STEPS, DEFAULT_THRESHOLD),
            'style': (STYLE_SEGA_CONCEPT_GUIDANCE_SCALE, STYLE_WARMUP_STEPS, STYLE_THRESHOLD),
            'object': (OBJECT_SEGA_CONCEPT_GUIDANCE_SCALE, OBJECT_WARMUP_STEPS, OBJECT_THRESHOLD),
            'faces': (FACE_SEGA_CONCEPT_GUIDANCE_SCALE, FACE_WARMUP_STEPS, FACE_THRESHOLD),
        }
        return presets.get(dropdown)

    def reset_do_inversion():
        """Flag that the image must be inverted again."""
        return True

    def reset_do_reconstruction():
        """Flag that the plain reconstruction must be recomputed."""
        return True

    def reset_image_caption():
        """Return an empty caption."""
        return ""

    def update_inversion_progress_visibility(input_image, do_inversion):
        """Show the inversion-progress box only while an inversion is pending for an image."""
        return gr.update(visible=bool(do_inversion and input_image is not None))

    def update_edit_progress_visibility(input_image, do_inversion):
        """Always show the edit progress (inputs currently ignored)."""
        return gr.update(visible=True)

    # ------------------------------------------------------------------
    # Layout
    # ------------------------------------------------------------------
    gr.HTML(intro)

    # Per-session state.
    wts = gr.State()
    zs = gr.State()
    attention_store = gr.State()
    text_cross_attention_maps = gr.State()
    reconstruction = gr.State()
    do_inversion = gr.State(value=True)
    do_reconstruction = gr.State(value=True)
    sega_concepts_counter = gr.State(0)
    image_caption = gr.State(value="")

    with gr.Row():
        input_image = gr.Image(label="Input Image", interactive=True, elem_id="input_image")
        ddpm_edited_image = gr.Image(label="Pure DDPM Inversion Image", interactive=False, visible=False)
        sega_edited_image = gr.Image(label="LEDITS Edited Image", interactive=False, elem_id="output_image")
    with gr.Group(visible=False, elem_id="share-btn-wrapper") as share_btn_container:
        with gr.Group(elem_id="share-btn-container"):
            community_icon = gr.HTML(community_icon_html, visible=True)
            loading_icon = gr.HTML(loading_icon_html, visible=False)
            share_button = gr.Button("Share to community", elem_id="share-btn", visible=True)

    # Display boxes for concepts that have been added.
    with gr.Row():
        with gr.Group(visible=False, elem_id="box1") as box1:
            with gr.Row():
                concept_1 = gr.Button(scale=3, value="")
                remove_concept1 = gr.Button("x", scale=1, min_width=10)
            with gr.Row():
                guidnace_scale_1 = gr.Slider(label='Concept Guidance Scale', minimum=1, maximum=30,
                                             info="How strongly the concept should modify the image",
                                             value=DEFAULT_SEGA_CONCEPT_GUIDANCE_SCALE,
                                             step=0.5, interactive=True)
        with gr.Group(visible=False, elem_id="box2") as box2:
            with gr.Row():
                concept_2 = gr.Button(scale=3, value="")
                remove_concept2 = gr.Button("x", scale=1, min_width=10)
            with gr.Row():
                guidnace_scale_2 = gr.Slider(label='Concept Guidance Scale', minimum=1, maximum=30,
                                             info="How strongly the concept should modify the image",
                                             value=DEFAULT_SEGA_CONCEPT_GUIDANCE_SCALE,
                                             step=0.5, interactive=True)
        with gr.Group(visible=False, elem_id="box3") as box3:
            with gr.Row():
                concept_3 = gr.Button(scale=3, value="")
                remove_concept3 = gr.Button("x", scale=1, min_width=10)
            with gr.Row():
                guidnace_scale_3 = gr.Slider(label='Concept Guidance Scale', minimum=1, maximum=30,
                                             info="How strongly the concept should modify the image",
                                             value=DEFAULT_SEGA_CONCEPT_GUIDANCE_SCALE,
                                             step=0.5, interactive=True)
    with gr.Row():
        inversion_progress = gr.Textbox(visible=False, label="Inversion progress")

    # Concept input rows (one visible at a time; row4 is the "max reached" notice).
    with gr.Group():
        intro_segs = gr.Markdown("Add/Remove Concepts from your Image <span style=\"font-size: 12px; color: rgb(156, 163, 175)\">with Semantic Guidance</span>")
        # 1st SEGA concept
        with gr.Row() as row1:
            with gr.Column(scale=3, min_width=100):
                with gr.Row():
                    edit_concept_1 = gr.Textbox(
                        label="Concept",
                        show_label=True,
                        max_lines=1, value="",
                        placeholder="E.g.: Sunglasses",
                    )
                    dropdown1 = gr.Dropdown(label="Edit Type", value='custom', choices=['custom', 'style', 'object', 'faces'])
            with gr.Column(scale=1, min_width=100, visible=False):
                neg_guidance_1 = gr.Checkbox(label='Remove Concept?')
            with gr.Column(scale=1, min_width=100):
                with gr.Row():  # better mobile ui
                    with gr.Column():
                        add_1 = gr.Button('Add')
                        remove_1 = gr.Button('Remove')
        # 2nd SEGA concept
        with gr.Row(visible=False) as row2:
            with gr.Column(scale=3, min_width=100):
                with gr.Row():  # better mobile ui
                    edit_concept_2 = gr.Textbox(
                        label="Concept",
                        show_label=True,
                        max_lines=1,
                        placeholder="E.g.: Realistic",
                    )
                    dropdown2 = gr.Dropdown(label="Edit Type", value='custom', choices=['custom', 'style', 'object', 'faces'])
            with gr.Column(scale=1, min_width=100, visible=False):
                neg_guidance_2 = gr.Checkbox(label='Remove Concept?')
            with gr.Column(scale=1, min_width=100):
                with gr.Row():  # better mobile ui
                    with gr.Column():
                        add_2 = gr.Button('Add')
                        remove_2 = gr.Button('Remove')
        # 3rd SEGA concept
        with gr.Row(visible=False) as row3:
            with gr.Column(scale=3, min_width=100):
                with gr.Row():  # better mobile ui
                    edit_concept_3 = gr.Textbox(
                        label="Concept",
                        show_label=True,
                        max_lines=1,
                        placeholder="E.g.: orange",
                    )
                    dropdown3 = gr.Dropdown(label="Edit Type", value='custom', choices=['custom', 'style', 'object', 'faces'])
            with gr.Column(scale=1, min_width=100, visible=False):
                neg_guidance_3 = gr.Checkbox(label='Remove Concept?', visible=True)
            with gr.Column(scale=1, min_width=100):
                with gr.Row():  # better mobile ui
                    with gr.Column():
                        add_3 = gr.Button('Add')
                        remove_3 = gr.Button('Remove')
        with gr.Row(visible=False) as row4:
            gr.Markdown("### Max of 3 concepts reached. Remove a concept to add more")

    with gr.Row():
        run_button = gr.Button("Edit your image!", visible=True)

    with gr.Accordion("Advanced Options", open=False):
        with gr.Row():
            tar_prompt = gr.Textbox(
                label="Describe your edited image (optional)",
                elem_id="target_prompt",
                max_lines=1, value="", scale=3,
                placeholder="Target prompt, DDPM Inversion", info="DPM Solver++ Inversion Prompt. Can help with global changes, modify to what you would like to see"
            )
        with gr.Tabs() as tabs:
            with gr.TabItem('General options', id=2):
                with gr.Row():
                    with gr.Column(min_width=100):
                        clear_button = gr.Button("Clear", visible=True)
                        src_prompt = gr.Textbox(lines=1, label="Source Prompt", interactive=True, placeholder="")
                        steps = gr.Number(value=50, precision=0, label="Num Diffusion Steps", interactive=True)
                        src_cfg_scale = gr.Number(value=3.5, label="Source Guidance Scale", interactive=True)
                        mask_type = gr.Radio(choices=["No mask", "Cross Attention Mask", "Intersect Mask"], value="Intersect Mask", label="Mask type")
                    with gr.Column(min_width=100):
                        reconstruct_button = gr.Button("Show Reconstruction", visible=False)
                        skip = gr.Slider(minimum=0, maximum=95, value=25, step=1, label="Skip Steps", interactive=True, info="Percentage of skipped denoising steps. Bigger values increase fidelity to input image")
                        tar_cfg_scale = gr.Slider(minimum=1, maximum=30, value=7.5, label="Guidance Scale", interactive=True)
                        seed = gr.Slider(minimum=0, maximum=np.iinfo(np.int32).max, label="Seed", interactive=True, randomize=True)
                        randomize_seed = gr.Checkbox(label='Randomize seed', value=False)
            with gr.TabItem('SEGA options', id=3) as sega_advanced_tab:
                # 1st SEGA concept
                gr.Markdown("1st concept")
                with gr.Row():
                    warmup_1 = gr.Slider(label='Warmup', minimum=0, maximum=50,
                                         value=DEFAULT_WARMUP_STEPS,
                                         step=1, interactive=True, info="At which step to start applying semantic guidance. Bigger values reduce edit concept's effect")
                    threshold_1 = gr.Slider(label='Threshold', minimum=0, maximum=0.99,
                                            value=DEFAULT_THRESHOLD, step=0.01, interactive=True,
                                            info="Lower the threshold for more effect (e.g. ~0.9 for style transfer)")
                # 2nd SEGA concept
                gr.Markdown("2nd concept")
                with gr.Row() as row2_advanced:
                    warmup_2 = gr.Slider(label='Warmup', minimum=0, maximum=50,
                                         value=DEFAULT_WARMUP_STEPS,
                                         step=1, interactive=True, info="At which step to start applying semantic guidance. Bigger values reduce edit concept's effect")
                    threshold_2 = gr.Slider(label='Threshold', minimum=0, maximum=0.99,
                                            value=DEFAULT_THRESHOLD,
                                            step=0.01, interactive=True,
                                            info="Lower the threshold for more effect (e.g. ~0.9 for style transfer)")
                # 3rd SEGA concept
                gr.Markdown("3rd concept")
                with gr.Row() as row3_advanced:
                    warmup_3 = gr.Slider(label='Warmup', minimum=0, maximum=50,
                                         value=DEFAULT_WARMUP_STEPS, step=1,
                                         interactive=True, info="At which step to start applying semantic guidance. Bigger values reduce edit concept's effect")
                    threshold_3 = gr.Slider(label='Threshold', minimum=0, maximum=0.99,
                                            value=DEFAULT_THRESHOLD, step=0.01,
                                            interactive=True,
                                            info="Lower the threshold for more effect (e.g. ~0.9 for style transfer)")

    # ------------------------------------------------------------------
    # Concept add / remove / preset wiring (identical for the three slots)
    # ------------------------------------------------------------------
    _all_concepts = [edit_concept_1, edit_concept_2, edit_concept_3]
    _concept_slots = (
        (1, add_1, remove_1, remove_concept1, edit_concept_1, neg_guidance_1, box1, concept_1,
         guidnace_scale_1, warmup_1, threshold_1, dropdown1, row1, row2),
        (2, add_2, remove_2, remove_concept2, edit_concept_2, neg_guidance_2, box2, concept_2,
         guidnace_scale_2, warmup_2, threshold_2, dropdown2, row2, row3),
        (3, add_3, remove_3, remove_concept3, edit_concept_3, neg_guidance_3, box3, concept_3,
         guidnace_scale_3, warmup_3, threshold_3, dropdown3, row3, row4),
    )
    for (_slot, _add, _remove, _close, _concept_text, _neg, _box, _concept_btn,
         _scale, _warmup, _threshold, _dropdown, _row, _next_row) in _concept_slots:
        _display_outputs = [_box, _concept_btn, _scale, _neg, _row, _next_row, sega_concepts_counter]
        # "Add": positive concept.
        _add.click(fn=update_counter, inputs=[sega_concepts_counter, *_all_concepts],
                   outputs=sega_concepts_counter, queue=False).then(
            fn=update_display_concept, inputs=[_add, _concept_text, _neg, sega_concepts_counter],
            outputs=_display_outputs, queue=False)
        # "Remove": negative concept (steer away from it).
        _remove.click(fn=update_display_concept, inputs=[_remove, _concept_text, _neg, sega_concepts_counter],
                      outputs=_display_outputs, queue=False)
        # "x" on a display box: reset that slot.
        _close.click(fn=update_counter, inputs=[sega_concepts_counter, *_all_concepts],
                     outputs=sega_concepts_counter, queue=False).then(
            fn=remove_concept, inputs=[sega_concepts_counter, gr.State(_slot)],
            outputs=[_box, _concept_btn, _concept_text, _scale, _neg, _warmup, _threshold, _add, _dropdown,
                     row1, row2, row3, row4, sega_concepts_counter],
            queue=False)
        # Edit-type presets.
        _dropdown.change(fn=update_dropdown_parms, inputs=[_dropdown],
                         outputs=[_scale, _warmup, _threshold], queue=False)

    # ------------------------------------------------------------------
    # Editing
    # ------------------------------------------------------------------
    run_button.click(
        fn=edit,
        inputs=[input_image,
                wts, zs, attention_store,
                text_cross_attention_maps,
                tar_prompt,
                image_caption,
                steps,
                skip,
                tar_cfg_scale,
                edit_concept_1, edit_concept_2, edit_concept_3,
                guidnace_scale_1, guidnace_scale_2, guidnace_scale_3,
                warmup_1, warmup_2, warmup_3,
                neg_guidance_1, neg_guidance_2, neg_guidance_3,
                threshold_1, threshold_2, threshold_3, do_reconstruction, reconstruction,
                do_inversion,
                seed,
                randomize_seed,
                src_prompt,
                src_cfg_scale,
                mask_type
                ],
        outputs=[sega_edited_image, reconstruct_button, do_reconstruction, reconstruction, wts, zs,
                 attention_store, text_cross_attention_maps, do_inversion, share_btn_container])

    input_image.change(
        fn=reset_do_inversion, outputs=[do_inversion], queue=False,
    ).then(
        fn=randomize_seed_fn, inputs=[seed, randomize_seed], outputs=[seed], queue=False,
    )

    # Automatically crop, caption and invert as soon as an image is uploaded.
    (input_image.upload(fn=crop_image, inputs=[input_image], outputs=[input_image], queue=False)
     .then(fn=reset_do_inversion, outputs=[do_inversion], queue=False)
     .then(fn=randomize_seed_fn, inputs=[seed, randomize_seed], outputs=[seed], queue=False)
     .then(fn=caption_image, inputs=[input_image], outputs=[tar_prompt, image_caption])
     .then(fn=update_inversion_progress_visibility, inputs=[input_image, do_inversion],
           outputs=[inversion_progress], queue=False)
     .then(fn=load_and_invert,
           inputs=[input_image,
                   do_inversion,
                   seed, randomize_seed,
                   wts, zs,
                   src_prompt,
                   steps,
                   src_cfg_scale,
                   skip,
                   tar_cfg_scale,
                   ],
           outputs=[wts, zs, do_inversion, inversion_progress])
     .then(fn=update_inversion_progress_visibility, inputs=[input_image, do_inversion],
           outputs=[inversion_progress], queue=False)
     .then(lambda: gr.update(visible=False), outputs=[reconstruct_button])
     .then(fn=reset_do_reconstruction, outputs=[do_reconstruction], queue=False))

    # Inversion parameters: changing one invalidates the inversion and the reconstruction.
    for _param in (src_prompt, steps, src_cfg_scale, skip, seed):
        _param.change(fn=reset_do_inversion, outputs=[do_inversion], queue=False).then(
            fn=reset_do_reconstruction, outputs=[do_reconstruction], queue=False)
    # Target parameters: changing one invalidates only the reconstruction.
    for _param in (tar_prompt, tar_cfg_scale):
        _param.change(fn=reset_do_reconstruction, outputs=[do_reconstruction], queue=False)

    # ------------------------------------------------------------------
    # Clear
    # ------------------------------------------------------------------
    clear_components = [input_image, ddpm_edited_image, ddpm_edited_image, sega_edited_image, do_inversion,
                        src_prompt, steps, src_cfg_scale, seed,
                        tar_prompt, skip, tar_cfg_scale, reconstruct_button, reconstruct_button,
                        edit_concept_1, guidnace_scale_1, guidnace_scale_1, warmup_1, threshold_1, neg_guidance_1, dropdown1, concept_1, concept_1, row1,
                        edit_concept_2, guidnace_scale_2, guidnace_scale_2, warmup_2, threshold_2, neg_guidance_2, dropdown2, concept_2, concept_2, row2,
                        edit_concept_3, guidnace_scale_3, guidnace_scale_3, warmup_3, threshold_3, neg_guidance_3, dropdown3, concept_3, concept_3, row3,
                        row4, sega_concepts_counter, box1, box2, box3]

    def _concept_reset_values(row_visible):
        """Reset values for one concept slot, matching its 10 entries in clear_components."""
        return ["", DEFAULT_SEGA_CONCEPT_GUIDANCE_SCALE, gr.update(visible=False),
                DEFAULT_WARMUP_STEPS, DEFAULT_THRESHOLD, DEFAULT_NEGATIVE_GUIDANCE,
                "custom", "", gr.update(visible=False), gr.update(visible=row_visible)]

    def _clear_values():
        """Values for `clear_components`, in the same order.

        Built fresh on every click so no update dict is shared between events.
        """
        return [None, None, gr.update(visible=False), None, True,
                "", DEFAULT_DIFFUSION_STEPS, DEFAULT_SOURCE_GUIDANCE_SCALE, DEFAULT_SEED,
                "", DEFAULT_SKIP_STEPS, DEFAULT_TARGET_GUIDANCE_SCALE,
                gr.update(value="Show Reconstruction"), gr.update(visible=False),
                *_concept_reset_values(True),
                *_concept_reset_values(False),
                *_concept_reset_values(False),
                gr.update(visible=False),  # row4
                # Plain 0: a gr.State output expects a raw value; an update
                # dict is not applied to State.
                0,
                gr.update(visible=False), gr.update(visible=False), gr.update(visible=False)]

    clear_button.click(_clear_values, outputs=clear_components)

    # ------------------------------------------------------------------
    # Reconstruction toggle
    # ------------------------------------------------------------------
    reconstruct_button.click(
        lambda: gr.update(visible=True), outputs=[ddpm_edited_image],
    ).then(
        fn=reconstruct,
        inputs=[tar_prompt,
                image_caption,
                tar_cfg_scale,
                skip,
                wts, zs,
                attention_store,
                text_cross_attention_maps,
                do_reconstruction,
                reconstruction,
                reconstruct_button],
        outputs=[ddpm_edited_image, reconstruction, ddpm_edited_image, do_reconstruction, reconstruct_button])

    randomize_seed.change(
        fn=randomize_seed_fn,
        inputs=[seed, randomize_seed],
        outputs=[seed],
        queue=False)

    share_button.click(None, [], [], js=share_js)

    gr.Examples(
        label='Examples',
        fn=swap_visibilities,
        run_on_click=True,
        examples=get_example(),
        inputs=[input_image,
                edit_concept_1,
                edit_concept_2,
                edit_concept_3,
                sega_edited_image,
                guidnace_scale_1,
                guidnace_scale_2,
                guidnace_scale_3,
                warmup_1,
                warmup_2,
                warmup_3,
                neg_guidance_1,
                neg_guidance_2,
                neg_guidance_3,
                steps,
                skip,
                tar_cfg_scale,
                threshold_1,
                threshold_2,
                threshold_3,
                seed,
                sega_concepts_counter
                ],
        outputs=[share_btn_container, box1, concept_1, guidnace_scale_1, neg_guidance_1, row1, row2,
                 box2, concept_2, guidnace_scale_2, neg_guidance_2, row2, row3, sega_concepts_counter],
        cache_examples=True
    )
# Enable the request queue, then start serving the app.
demo.queue().launch()