# *************************************************************************
# Copyright (2023) Bytedance Inc.
#
# Copyright (2023) DragDiffusion Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# *************************************************************************
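"""Gradio UI for DragDiffusion (drag_ui.py).

The interface has two tabs: "Editing Real Image" (upload an image, draw a
mask, click drag points, optionally train a LoRA, then run the drag edit)
and "Editing Generated Image" (generate an image with a chosen diffusion
model, then drag-edit it using the cached intermediate latents). All of the
callback logic lives in utils/ui_utils.py; this file only wires UI
components to those callbacks.
"""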
import os

import gradio as gr

from utils.ui_utils import get_points, undo_points
from utils.ui_utils import clear_all, store_img, train_lora_interface, run_drag
from utils.ui_utils import clear_all_gen, store_img_gen, gen_img, run_drag_gen

LENGTH = 480  # side length of the square areas used to display/edit images
with gr.Blocks() as demo:
    # layout definition
    with gr.Row():
        gr.Markdown("""
        # Official Implementation of [DragDiffusion](https://arxiv.org/abs/2306.14435)
        """)
    # UI components for editing real images
    with gr.Tab(label="Editing Real Image"):
        mask = gr.State(value=None)            # store mask
        selected_points = gr.State([])         # store points
        original_image = gr.State(value=None)  # store original input image
        with gr.Row():
            with gr.Column():
                gr.Markdown("""<p style="text-align: center; font-size: 20px">Draw Mask</p>""")
                canvas = gr.Image(type="numpy", tool="sketch", label="Draw Mask",
                                  show_label=True, height=LENGTH, width=LENGTH)  # for mask painting
                train_lora_button = gr.Button("Train LoRA")
            with gr.Column():
                gr.Markdown("""<p style="text-align: center; font-size: 20px">Click Points</p>""")
                input_image = gr.Image(type="numpy", label="Click Points",
                                       show_label=True, height=LENGTH, width=LENGTH)  # for clicking points
                undo_button = gr.Button("Undo point")
            with gr.Column():
                gr.Markdown("""<p style="text-align: center; font-size: 20px">Editing Results</p>""")
                output_image = gr.Image(type="numpy", label="Editing Results",
                                        show_label=True, height=LENGTH, width=LENGTH)
                with gr.Row():
                    run_button = gr.Button("Run")
                    clear_all_button = gr.Button("Clear All")
        # general parameters
        with gr.Row():
            prompt = gr.Textbox(label="Prompt")
            lora_path = gr.Textbox(value="./lora_tmp", label="LoRA path")
            lora_status_bar = gr.Textbox(label="LoRA training status")

        # algorithm-specific parameters
        with gr.Tab("Drag Config"):
            with gr.Row():
                n_pix_step = gr.Number(
                    value=40,
                    label="number of pixel steps",
                    info="Number of gradient-descent (motion supervision) steps on the latent.",
                    precision=0)
                lam = gr.Number(value=0.1, label="lam",
                                info="regularization strength on unmasked areas")
                # n_actual_inference_step = gr.Number(value=40, label="optimize latent step", precision=0)
                inversion_strength = gr.Slider(0, 1.0,
                                               value=0.75,
                                               label="inversion strength",
                                               info="The latent at [inversion-strength * total-sampling-steps] is optimized for dragging.")
                latent_lr = gr.Number(value=0.01, label="latent lr")
                start_step = gr.Number(value=0, label="start_step", precision=0, visible=False)
                start_layer = gr.Number(value=10, label="start_layer", precision=0, visible=False)
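        # Worked example for the inversion-strength slider: assuming 50 total
        # sampling steps (the default exposed in the generation tab below),
        # a strength of 0.75 means the image is inverted to step
        # 0.75 * 50 = 37.5, i.e. roughly step 38, and it is that noised
        # latent that the pixel-step optimization updates.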
with gr.Tab("Base Model Config"):
with gr.Row():
local_models_dir = 'local_pretrained_models'
local_models_choice = \
[os.path.join(local_models_dir,d) for d in os.listdir(local_models_dir) if os.path.isdir(os.path.join(local_models_dir,d))]
model_path = gr.Dropdown(value="runwayml/stable-diffusion-v1-5",
label="Diffusion Model Path",
choices=[
"runwayml/stable-diffusion-v1-5",
] + local_models_choice
)
vae_path = gr.Dropdown(value="default",
label="VAE choice",
choices=["default",
"stabilityai/sd-vae-ft-mse"] + local_models_choice
)
with gr.Tab("LoRA Parameters"):
with gr.Row():
lora_step = gr.Number(value=200, label="LoRA training steps", precision=0)
lora_lr = gr.Number(value=0.0002, label="LoRA learning rate")
lora_rank = gr.Number(value=16, label="LoRA rank", precision=0)
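        # As wired up in the event section below, "Train LoRA" forwards the
        # uploaded image, the prompt, the base model/VAE choices, and the
        # three numbers above to train_lora_interface; writing the trained
        # LoRA weights into `lora_path` is an assumption based on the
        # "LoRA path" textbox.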
    # UI components for editing generated images
    with gr.Tab(label="Editing Generated Image"):
        mask_gen = gr.State(value=None)                  # store mask
        selected_points_gen = gr.State([])               # store points
        original_image_gen = gr.State(value=None)        # store the diffusion-generated image
        intermediate_latents_gen = gr.State(value=None)  # store the intermediate diffusion latents during generation
        with gr.Row():
            with gr.Column():
                gr.Markdown("""<p style="text-align: center; font-size: 20px">Draw Mask</p>""")
                canvas_gen = gr.Image(type="numpy", tool="sketch", label="Draw Mask",
                                      show_label=True, height=LENGTH, width=LENGTH)  # for mask painting
                gen_img_button = gr.Button("Generate Image")
            with gr.Column():
                gr.Markdown("""<p style="text-align: center; font-size: 20px">Click Points</p>""")
                input_image_gen = gr.Image(type="numpy", label="Click Points",
                                           show_label=True, height=LENGTH, width=LENGTH)  # for clicking points
                undo_button_gen = gr.Button("Undo point")
            with gr.Column():
                gr.Markdown("""<p style="text-align: center; font-size: 20px">Editing Results</p>""")
                output_image_gen = gr.Image(type="numpy", label="Editing Results",
                                            show_label=True, height=LENGTH, width=LENGTH)
                with gr.Row():
                    run_button_gen = gr.Button("Run")
                    clear_all_button_gen = gr.Button("Clear All")
        # general parameters
        with gr.Row():
            pos_prompt_gen = gr.Textbox(label="Positive Prompt")
            neg_prompt_gen = gr.Textbox(label="Negative Prompt")

        with gr.Tab("Generation Config"):
            with gr.Row():
                # same local-model discovery as in the real-image tab
                local_models_dir = 'local_pretrained_models'
                if os.path.isdir(local_models_dir):
                    local_models_choice = [
                        os.path.join(local_models_dir, d)
                        for d in os.listdir(local_models_dir)
                        if os.path.isdir(os.path.join(local_models_dir, d))
                    ]
                else:
                    local_models_choice = []
                model_path_gen = gr.Dropdown(value="runwayml/stable-diffusion-v1-5",
                                             label="Diffusion Model Path",
                                             choices=[
                                                 "runwayml/stable-diffusion-v1-5",
                                                 "gsdf/Counterfeit-V2.5",
                                                 "emilianJR/majicMIX_realistic",
                                                 "SG161222/Realistic_Vision_V2.0",
                                                 "stablediffusionapi/landscapesupermix",
                                                 "huangzhe0803/ArchitectureRealMix",
                                                 "stablediffusionapi/interiordesignsuperm"
                                             ] + local_models_choice
                                             )
                vae_path_gen = gr.Dropdown(value="default",
                                           label="VAE choice",
                                           choices=["default",
                                                    "stabilityai/sd-vae-ft-mse"] + local_models_choice
                                           )
                lora_path_gen = gr.Textbox(value="", label="LoRA path")
                gen_seed = gr.Number(value=65536, label="Generation Seed", precision=0)
                height = gr.Number(value=512, label="Height", precision=0)
                width = gr.Number(value=512, label="Width", precision=0)
                guidance_scale = gr.Number(value=7.5, label="CFG Scale")
                scheduler_name_gen = gr.Dropdown(
                    value="DDIM",
                    label="Scheduler",
                    choices=[
                        "DDIM",
                        "DPM++2M",
                        "DPM++2M_karras"
                    ]
                )
                n_inference_step_gen = gr.Number(value=50, label="Total Sampling Steps", precision=0)

        with gr.Tab(label="Drag Config"):
            with gr.Row():
                n_pix_step_gen = gr.Number(
                    value=40,
                    label="Number of Pixel Steps",
                    info="Number of gradient-descent (motion supervision) steps on the latent.",
                    precision=0)
                lam_gen = gr.Number(value=0.1, label="lam",
                                    info="regularization strength on unmasked areas")
                # n_actual_inference_step_gen = gr.Number(value=40, label="optimize latent step", precision=0)
                inversion_strength_gen = gr.Slider(0, 1.0,
                                                   value=0.75,
                                                   label="Inversion Strength",
                                                   info="The latent at [inversion-strength * total-sampling-steps] is optimized for dragging.")
                latent_lr_gen = gr.Number(value=0.01, label="latent lr")
                start_step_gen = gr.Number(value=0, label="start_step", precision=0, visible=False)
                start_layer_gen = gr.Number(value=10, label="start_layer", precision=0, visible=False)

        # options for exporting a GIF (and a point-tracking overlay) of the drag process
        with gr.Row():
            create_gif_checkbox = gr.Checkbox(label="create_GIF", value=False)
            create_tracking_point_checkbox = gr.Checkbox(label="create_tracking_point", value=False)
            gif_interval = gr.Number(value=10, label="interval_GIF", precision=0,
                                     info="Number of optimization steps between consecutive GIF frames.")
            gif_fps = gr.Number(value=1, label="fps_GIF", precision=0,
                                info="Frames per second of the exported GIF.")
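        # Example with the defaults above (assuming run_drag_gen snapshots the
        # result every `gif_interval` optimization steps, as the info strings
        # suggest): n_pix_step_gen=40 and gif_interval=10 yield roughly
        # 40 / 10 = 4 frames, played back at gif_fps=1, i.e. one frame per second.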
    # event definition
    # events for dragging a user-input real image
    canvas.edit(
        store_img,
        [canvas],
        [original_image, selected_points, input_image, mask]
    )
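    # Inferred from the wiring above (the actual definition lives in
    # utils/ui_utils.py): store_img receives the sketch-canvas payload and
    # returns (original_image, selected_points, input_image, mask), i.e. it
    # caches the uploaded image, resets the clicked points, and extracts the
    # painted mask.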
    input_image.select(
        get_points,
        [input_image, selected_points],
        [input_image],
    )
    undo_button.click(
        undo_points,
        [original_image, mask],
        [input_image, selected_points]
    )
    train_lora_button.click(
        train_lora_interface,
        [original_image,
         prompt,
         model_path,
         vae_path,
         lora_path,
         lora_step,
         lora_lr,
         lora_rank],
        [lora_status_bar]
    )
    run_button.click(
        run_drag,
        [original_image,
         input_image,
         mask,
         prompt,
         selected_points,
         inversion_strength,
         lam,
         latent_lr,
         n_pix_step,
         model_path,
         vae_path,
         lora_path,
         start_step,
         start_layer,
         create_gif_checkbox,
         gif_interval,
         ],
        [output_image]
    )
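    # Note: Gradio passes the values of the components in an inputs list
    # positionally, so the order above must match run_drag's parameter order
    # exactly.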
    clear_all_button.click(
        clear_all,
        [gr.Number(value=LENGTH, visible=False, precision=0)],
        [canvas,
         input_image,
         output_image,
         selected_points,
         original_image,
         mask]
    )
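    # The invisible gr.Number is a common Gradio idiom for feeding a Python
    # constant (here LENGTH) into a callback that only accepts component
    # inputs; clear_all presumably uses it to rebuild blank canvases of the
    # right size.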
    # events for dragging a generated image
    canvas_gen.edit(
        store_img_gen,
        [canvas_gen],
        [original_image_gen, selected_points_gen, input_image_gen, mask_gen]
    )
    input_image_gen.select(
        get_points,
        [input_image_gen, selected_points_gen],
        [input_image_gen],
    )
    gen_img_button.click(
        gen_img,
        [
            gr.Number(value=LENGTH, visible=False, precision=0),
            height,
            width,
            n_inference_step_gen,
            scheduler_name_gen,
            gen_seed,
            guidance_scale,
            pos_prompt_gen,
            neg_prompt_gen,
            model_path_gen,
            vae_path_gen,
            lora_path_gen,
        ],
        [canvas_gen, input_image_gen, output_image_gen, mask_gen, intermediate_latents_gen]
    )
    undo_button_gen.click(
        undo_points,
        [original_image_gen, mask_gen],
        [input_image_gen, selected_points_gen]
    )
    run_button_gen.click(
        run_drag_gen,
        [
            n_inference_step_gen,
            scheduler_name_gen,
            original_image_gen,   # the original image generated by the diffusion model
            input_image_gen,      # image with the clicked points and mask drawn on it
            intermediate_latents_gen,
            guidance_scale,
            mask_gen,
            pos_prompt_gen,
            neg_prompt_gen,
            selected_points_gen,
            inversion_strength_gen,
            lam_gen,
            latent_lr_gen,
            n_pix_step_gen,
            model_path_gen,
            vae_path_gen,
            lora_path_gen,
            start_step_gen,
            start_layer_gen,
            create_gif_checkbox,
            create_tracking_point_checkbox,
            gif_interval,
            gif_fps
        ],
        [output_image_gen]
    )
    clear_all_button_gen.click(
        clear_all_gen,
        [gr.Number(value=LENGTH, visible=False, precision=0)],
        [canvas_gen,
         input_image_gen,
         output_image_gen,
         selected_points_gen,
         original_image_gen,
         mask_gen,
         intermediate_latents_gen,
         ]
    )

demo.queue().launch(share=True, debug=True)
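# share=True asks Gradio to create a temporary public gradio.live link in
# addition to the local server; debug=True keeps the process attached so
# server-side errors are printed to the console.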