# EditAnything / editany_demo.py — shgao, "update new demo" (commit 0c7479d)
# Edit Anything trained with Stable Diffusion + ControlNet + SAM + BLIP2
import gradio as gr
def create_demo_template(
    process,
    process_image_click=None,
    examples=None,
    INFO="EditAnything https://github.com/sail-sg/EditAnything",
    WARNING_INFO=None,
    enable_auto_prompt_default=False,
):
    """Build the shared Gradio UI for the EditAnything demos.

    Args:
        process: Generation callback wired to both "Run" buttons. It receives
            the widget values listed in ``ips`` / ``ip_click`` below (the two
            lists differ only in how the image and edit mask are supplied) and
            returns (refined gallery, initial gallery, reference gallery, text).
        process_image_click: Callback fired when the user clicks the Click-tab
            image; accumulates point prompts and produces a click mask.
        examples: Optional rows for ``gr.Examples`` (positive prompt, negative
            prompt, guidance scale).
        INFO: Markdown shown at the top of the page.
        WARNING_INFO: Optional markdown shown at the bottom of the page.
        enable_auto_prompt_default: Initial state of the BLIP2 auto-prompt
            checkbox.

    Returns:
        The assembled ``gr.Blocks`` demo (not launched).
    """
    # NOTE(review): indentation was lost in the source paste; the layout
    # nesting below is reconstructed from the context-manager structure —
    # verify the exact column/row grouping against the upstream repo.
    print("The GUI is not fully tested yet. Please open an issue if you find bugs.")
    block = gr.Blocks()
    with block as demo:
        # Hidden state for the point-click (SAM) workflow.
        clicked_points = gr.State([])
        origin_image = gr.State(None)  # pristine copy of the uploaded image
        click_mask = gr.State(None)
        # Reference-image click state; currently unused — the reference
        # region is selected with the sketch tool instead.
        ref_clicked_points = gr.State([])
        ref_origin_image = gr.State(None)
        ref_click_mask = gr.State(None)
        with gr.Row():
            gr.Markdown(INFO)
        with gr.Row().style(equal_height=False):
            with gr.Column():
                with gr.Tab("Click🖱"):
                    source_image_click = gr.Image(
                        type="pil",
                        interactive=True,
                        label="Image: Upload an image and click the region you want to edit.",
                    )
                    with gr.Column():
                        with gr.Row():
                            point_prompt = gr.Radio(
                                choices=["Foreground Point",
                                         "Background Point"],
                                value="Foreground Point",
                                label="Point Label",
                                interactive=True,
                                show_label=False,
                            )
                            clear_button_click = gr.Button(
                                value="Clear Click Points", interactive=True
                            )
                            clear_button_image = gr.Button(
                                value="Clear Image", interactive=True
                            )
                        with gr.Row():
                            # FIX: label read "Run EditAnying".
                            run_button_click = gr.Button(
                                label="Run EditAnything", interactive=True
                            )
                with gr.Tab("Brush🖌️"):
                    source_image_brush = gr.Image(
                        source="upload",
                        label="Image: Upload an image and cover the region you want to edit with sketch",
                        type="numpy",
                        tool="sketch",
                    )
                    # FIX: label read "Run EditAnying".
                    run_button = gr.Button(
                        label="Run EditAnything", interactive=True)
                with gr.Column():
                    enable_all_generate = gr.Checkbox(
                        label="Auto generation on all region.", value=False
                    )
                    control_scale = gr.Slider(
                        label="Mask Align strength",
                        info="Large value -> strict alignment with SAM mask",
                        minimum=0,
                        maximum=1,
                        value=0.5,
                        step=0.1,
                    )
                with gr.Column():
                    enable_auto_prompt = gr.Checkbox(
                        label="Auto generate text prompt from input image with BLIP2",
                        # FIX: grammar of user-facing warning text.
                        info="Warning: enabling this may make your prompt ineffective.",
                        value=enable_auto_prompt_default,
                    )
                a_prompt = gr.Textbox(
                    label="Positive Prompt",
                    info="Text in the expected things of edited region",
                    value="best quality, extremely detailed,",
                )
                n_prompt = gr.Textbox(
                    label="Negative Prompt",
                    value="longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality, NSFW",
                )
                with gr.Row():
                    num_samples = gr.Slider(
                        label="Images", minimum=1, maximum=12, value=2, step=1
                    )
                    seed = gr.Slider(
                        label="Seed",
                        minimum=-1,
                        maximum=2147483647,
                        step=1,
                        randomize=True,
                    )
                with gr.Row():
                    enable_tile = gr.Checkbox(
                        label="Tile refinement for high resolution generation",
                        info="Slow inference",
                        value=True,
                    )
                    refine_alignment_ratio = gr.Slider(
                        label="Alignment Strength",
                        info="Large value -> strict alignment with input image. Small value -> strong global consistency",
                        minimum=0.0,
                        maximum=1.0,
                        value=0.95,
                        step=0.05,
                    )
                with gr.Accordion("Reference options", open=False):
                    ref_image = gr.Image(
                        source="upload",
                        label="Upload a reference image and cover the region you want to use with sketch",
                        type="pil",
                        tool="sketch",
                    )
                    with gr.Column():
                        ref_auto_prompt = gr.Checkbox(
                            label="Ref. Auto Prompt", value=True
                        )
                        ref_prompt = gr.Textbox(
                            label="Prompt",
                            info="Text in the prompt of edited region",
                            value="best quality, extremely detailed, ",
                        )
                    with gr.Row():
                        reference_attn = gr.Checkbox(
                            label="reference_attn", value=True)
                        attention_auto_machine_weight = gr.Slider(
                            label="attention_weight",
                            minimum=0,
                            maximum=1.0,
                            value=0.8,
                            step=0.01,
                        )
                    with gr.Row():
                        reference_adain = gr.Checkbox(
                            label="reference_adain", value=False
                        )
                        gn_auto_machine_weight = gr.Slider(
                            label="gn_weight",
                            minimum=0,
                            maximum=1.0,
                            value=0.1,
                            step=0.01,
                        )
                    style_fidelity = gr.Slider(
                        label="Style fidelity",
                        minimum=0,
                        maximum=1.0,
                        value=0.5,
                        step=0.01,
                    )
                    ref_sam_scale = gr.Slider(
                        label="SAM Control Scale",
                        minimum=0,
                        maximum=1.0,
                        value=0.3,
                        step=0.1,
                    )
                    ref_inpaint_scale = gr.Slider(
                        label="Inpaint Control Scale",
                        minimum=0,
                        maximum=1.0,
                        value=0.2,
                        step=0.1,
                    )
                    with gr.Row():
                        ref_textinv = gr.Checkbox(
                            label="Use textual inversion token", value=False
                        )
                        ref_textinv_path = gr.Textbox(
                            label="textual inversion token path",
                            info="Text in the inversion token path",
                            value=None,
                        )
                with gr.Accordion("Advanced options", open=False):
                    mask_image = gr.Image(
                        source="upload",
                        label="Upload a predefined mask of edit region: Switch to Brush mode when using this!",
                        type="numpy",
                        value=None,
                    )
                    image_resolution = gr.Slider(
                        label="Image Resolution",
                        minimum=256,
                        maximum=768,
                        value=512,
                        step=64,
                    )
                    # FIX: this slider previously reused the "Image Resolution"
                    # label, making it indistinguishable from the one above.
                    refine_image_resolution = gr.Slider(
                        label="Refine Image Resolution",
                        minimum=256,
                        maximum=8192,
                        value=1024,
                        step=64,
                    )
                    guess_mode = gr.Checkbox(label="Guess Mode", value=False)
                    detect_resolution = gr.Slider(
                        label="SAM Resolution",
                        minimum=128,
                        maximum=2048,
                        value=1024,
                        step=1,
                    )
                    ddim_steps = gr.Slider(
                        label="Steps", minimum=1, maximum=100, value=30, step=1
                    )
                    scale = gr.Slider(
                        label="Guidance Scale",
                        minimum=0.1,
                        maximum=30.0,
                        value=9.0,
                        step=0.1,
                    )
                    alpha_weight = gr.Slider(
                        label="Alpha weight",
                        info="Alpha mixing with original image",
                        minimum=0,
                        maximum=1,
                        value=0.0,
                        step=0.1,
                    )
                    use_scale_map = gr.Checkbox(
                        label='Use scale map', value=False)
                    eta = gr.Number(label="eta (DDIM)", value=0.0)
                    condition_model = gr.Textbox(
                        label="Condition model path",
                        # FIX: "hugglingface" typo.
                        info="Text in the Controlnet model path in huggingface",
                        value="EditAnything",
                    )
            with gr.Column():
                result_gallery_refine = gr.Gallery(
                    label="Output High quality", show_label=True, elem_id="gallery"
                ).style(grid=2, preview=False)
                result_gallery_init = gr.Gallery(
                    label="Output Low quality", show_label=True, elem_id="gallery"
                ).style(grid=2, height="auto")
                result_gallery_ref = gr.Gallery(
                    label="Output Ref", show_label=False, elem_id="gallery"
                ).style(grid=2, height="auto")
                result_text = gr.Text(label="BLIP2+Human Prompt Text")

        # Both run paths pass the same 30-widget tail to `process`; only the
        # source image and edit mask differ (brush sketch vs. click mask).
        # Factored out so the two argument lists cannot drift apart.
        common_tail = [
            control_scale,
            enable_auto_prompt,
            a_prompt,
            n_prompt,
            num_samples,
            image_resolution,
            detect_resolution,
            ddim_steps,
            guess_mode,
            scale,
            seed,
            eta,
            enable_tile,
            refine_alignment_ratio,
            refine_image_resolution,
            alpha_weight,
            use_scale_map,
            condition_model,
            ref_image,
            attention_auto_machine_weight,
            gn_auto_machine_weight,
            style_fidelity,
            reference_attn,
            reference_adain,
            ref_prompt,
            ref_sam_scale,
            ref_inpaint_scale,
            ref_auto_prompt,
            ref_textinv,
            ref_textinv_path,
        ]
        run_outputs = [
            result_gallery_refine,
            result_gallery_init,
            result_gallery_ref,
            result_text,
        ]
        ips = [source_image_brush, enable_all_generate, mask_image] + common_tail
        run_button.click(
            fn=process,
            inputs=ips,
            outputs=run_outputs,
        )
        ip_click = [origin_image, enable_all_generate, click_mask] + common_tail
        run_button_click.click(
            fn=process,
            inputs=ip_click,
            outputs=run_outputs,
        )
        # Cache a pristine copy of the upload so click markers can later be
        # cleared without forcing a re-upload.
        source_image_click.upload(
            lambda image: image.copy() if image is not None else None,
            inputs=[source_image_click],
            outputs=[origin_image],
        )
        source_image_click.select(
            process_image_click,
            inputs=[origin_image, point_prompt,
                    clicked_points, image_resolution],
            outputs=[source_image_click, clicked_points, click_mask],
            show_progress=True,
            queue=True,
        )
        # Restore the unmarked image and drop accumulated points/mask.
        clear_button_click.click(
            fn=lambda original_image: (original_image.copy(), [], None)
            if original_image is not None
            else (None, [], None),
            inputs=[origin_image],
            outputs=[source_image_click, clicked_points, click_mask],
        )
        # Full reset of the Click tab, including the low-quality gallery.
        clear_button_image.click(
            fn=lambda: (None, [], None, None, None),
            inputs=[],
            outputs=[
                source_image_click,
                clicked_points,
                click_mask,
                result_gallery_init,
                result_text,
            ],
        )
        if examples is not None:
            with gr.Row():
                gr.Examples(
                    examples=examples,
                    fn=process,
                    inputs=[a_prompt, n_prompt, scale],
                    outputs=[result_gallery_init],
                    cache_examples=False,
                )
        if WARNING_INFO is not None:
            with gr.Row():
                gr.Markdown(WARNING_INFO)
    return demo