d-edit / app.py
afeng's picture
second
f4f90db
raw
history blame
17.1 kB
import os
import copy
from PIL import Image
import matplotlib
import numpy as np
import gradio as gr
from utils import load_mask, load_mask_edit
from utils_mask import process_mask_to_follow_priority, mask_union, visualize_mask_list_clean
from pathlib import Path
import subprocess
from PIL import Image
from functools import partial
from main import run_main
LENGTH=512  # side length of the square area used to display/edit images (passed as both height and width of the canvas)
TRANSPARENCY = 150  # alpha value given to the segmentation overlay when it is pasted over the image
def add_mask(mask_np_list_updated, mask_label_list):
    """Append an empty mask labelled "new" and return both lists.

    The new mask is an all-zero array with the same shape and dtype as the
    first existing mask.  Both input lists are extended in place, and the
    very same list objects are returned.
    """
    template = mask_np_list_updated[0]
    mask_np_list_updated.extend([np.zeros_like(template)])
    mask_label_list.extend(["new"])
    return mask_np_list_updated, mask_label_list
def create_segmentation(mask_np_list):
    """Render a list of masks as a single colour-coded segmentation image.

    Each mask is multiplied by its own colour, taken from a viridis colormap
    resampled to ``len(mask_np_list)`` entries, and the coloured layers are
    summed.

    Args:
        mask_np_list: non-empty sequence of 2-D numpy masks, all of the same
            shape (presumably 0/1 valued and non-overlapping -- overlapping
            pixels would add up and may overflow; TODO confirm).

    Returns:
        A PIL image built from the summed layers scaled by 255 and cast to
        uint8.

    Raises:
        ValueError: if ``mask_np_list`` is empty.
    """
    # `import matplotlib` alone does not guarantee that the `pyplot` and
    # `colors` submodules are loaded, so import them explicitly here.
    import matplotlib.colors
    import matplotlib.pyplot

    if len(mask_np_list) == 0:
        raise ValueError("mask_np_list must contain at least one mask")

    viridis = matplotlib.pyplot.get_cmap(name='viridis', lut=len(mask_np_list))
    segmentation = 0
    for i, m in enumerate(mask_np_list):
        color = matplotlib.colors.to_rgb(viridis(i))
        # One (H, W, 3) layer: the mask scaled by each RGB component.
        segmentation = segmentation + np.stack([m * channel for channel in color], axis=2)
    return Image.fromarray(np.uint8(segmentation * 255))
def load_mask_ui(input_folder="example_tmp",load_edit = False):
    """Load the masks stored in ``input_folder`` as numpy arrays.

    Args:
        input_folder: folder handed to the mask loader.
        load_edit: when true the masks come from ``load_mask_edit``,
            otherwise from ``load_mask``.

    Returns:
        ``(mask_np_list, mask_label_list)`` where every loaded mask has been
        converted with ``.cpu().numpy()`` (presumably torch tensors -- TODO
        confirm against ``utils``).
    """
    loader = load_mask_edit if load_edit else load_mask
    mask_list, mask_label_list = loader(input_folder)
    mask_np_list = [tensor_mask.cpu().numpy() for tensor_mask in mask_list]
    return mask_np_list, mask_label_list
def load_image_ui(load_edit, input_folder="example_tmp"):
    """Load ``img_512.png`` and its masks from ``input_folder`` for the UI.

    Args:
        load_edit: forwarded to ``load_mask_ui``; selects the edited masks
            instead of the original ones.
        input_folder: folder expected to contain ``img_512.png`` and the
            saved masks.

    Returns:
        ``(image, segmentation, mask_np_list, mask_label_list, image)`` on
        success, where ``image`` is the RGB image and ``segmentation`` the
        colour-coded mask image.  When the image is missing or anything
        fails while loading, a tuple of five ``None`` values is returned so
        that the caller always receives five outputs.
    """
    nothing_loaded = (None, None, None, None, None)
    try:
        img_path = Path(input_folder) / "img_512.png"
        if not img_path.is_file():
            # Covers both a missing folder and a folder without the image.
            print("Image folder invalid: The folder should contain img_512.png")
            return nothing_loaded
        image = Image.open(img_path).convert('RGB')
        mask_np_list, mask_label_list = load_mask_ui(input_folder, load_edit = load_edit)
        segmentation = create_segmentation(mask_np_list)
        print("!!", len(mask_np_list))
        return image, segmentation, mask_np_list, mask_label_list, image
    except Exception as err:
        # Keep the UI alive on a bad folder, but say what actually went wrong.
        print("Image folder invalid: The folder should contain img_512.png ({})".format(err))
        return nothing_loaded
def run_edit_text(
    num_tokens,
    num_sampling_steps,
    strength,
    edge_thickness,
    tgt_prompt,
    tgt_idx,
    guidance_scale,
    input_folder="example_tmp"
):
    """Run ``main.py`` in text-editing mode as a subprocess and load the result.

    Every argument is forwarded to ``main.py`` as a ``--flag=value`` option;
    resolution (512), seed (2024), ``--dpm=sd`` and ``--num_imgs=2`` are fixed.

    Returns:
        The PIL image opened from ``<input_folder>/text/out_text_0.png``.

    Raises:
        subprocess.CalledProcessError: if ``main.py`` exits with a non-zero
            status, so a stale or missing result is never opened silently.
    """
    import sys

    # sys.executable runs the child with the same interpreter (and therefore
    # the same environment) as this app, instead of whatever "python" is on PATH.
    command = [
        sys.executable,
        "main.py",
        "--text",
        "--name={}".format(input_folder),
        "--dpm={}".format("sd"),
        "--resolution={}".format(512),
        "--load_trained",
        "--num_tokens={}".format(num_tokens),
        "--seed={}".format(2024),
        "--guidance_scale={}".format(guidance_scale),
        # NOTE(review): flag is singular here while the run_main keyword used
        # elsewhere in this file is num_sampling_steps -- confirm against main.py.
        "--num_sampling_step={}".format(num_sampling_steps),
        "--strength={}".format(strength),
        "--edge_thickness={}".format(edge_thickness),
        "--num_imgs={}".format(2),
        "--tgt_prompt={}".format(tgt_prompt),
        "--tgt_index={}".format(tgt_idx),
    ]
    subprocess.run(command, check=True)
    return Image.open(os.path.join(input_folder, "text", "out_text_0.png"))
def run_optimization(
    num_tokens,
    embedding_learning_rate,
    max_emb_train_steps,
    diffusion_model_learning_rate,
    max_diffusion_train_steps,
    train_batch_size,
    gradient_accumulation_steps,
    input_folder = "example_tmp"
):
    """Run ``main.py`` as a subprocess with the given optimisation settings.

    Every argument is forwarded to ``main.py`` as a ``--flag=value`` option;
    ``--dpm=sd`` and ``--resolution=512`` are fixed.

    Returns:
        None.

    Raises:
        subprocess.CalledProcessError: if ``main.py`` exits with a non-zero
            status, so a failed run is not mistaken for a finished one.
    """
    import sys

    # sys.executable runs the child with the same interpreter (and therefore
    # the same environment) as this app, instead of whatever "python" is on PATH.
    command = [
        sys.executable,
        "main.py",
        "--name={}".format(input_folder),
        "--dpm={}".format("sd"),
        "--resolution={}".format(512),
        "--num_tokens={}".format(num_tokens),
        "--embedding_learning_rate={}".format(embedding_learning_rate),
        "--diffusion_model_learning_rate={}".format(diffusion_model_learning_rate),
        "--max_emb_train_steps={}".format(max_emb_train_steps),
        "--max_diffusion_train_steps={}".format(max_diffusion_train_steps),
        "--train_batch_size={}".format(train_batch_size),
        "--gradient_accumulation_steps={}".format(gradient_accumulation_steps),
    ]
    subprocess.run(command, check=True)
    return
def transparent_paste_with_mask(backimg, foreimg, mask_np,transparency = 128):
    """Overlay ``foreimg`` on ``backimg`` semi-transparently, inside a mask only.

    ``foreimg`` is given a uniform alpha of ``transparency`` and pasted over a
    copy of ``backimg``.  The result is that blended image where ``mask_np``
    is 1 and the untouched ``backimg`` where it is 0.

    Args:
        backimg: PIL background image.
        foreimg: PIL image to overlay (same size as ``backimg``).
        mask_np: 2-D array of shape (height, width); any numeric or boolean
            dtype is accepted.
        transparency: alpha value (0-255) applied to ``foreimg``.

    Returns:
        A new PIL image; neither input image is modified.

    Raises:
        ValueError: if ``mask_np`` is not 2-D or does not match the image's
            (height, width).
    """
    backimg_solid_np = np.array(backimg)
    bimg = backimg.copy()
    fimg = foreimg.copy()
    fimg.putalpha(transparency)
    bimg.paste(fimg, (0, 0), fimg)
    bimg_np = np.array(bimg)

    mask = np.asarray(mask_np)
    if mask.ndim != 2 or mask.shape != bimg_np.shape[:2]:
        raise ValueError(
            "mask shape {} does not match image shape {}".format(mask.shape, bimg_np.shape[:2])
        )

    # Blend in floating point so bool/float masks work too, then convert back
    # to uint8, which is what Image.fromarray needs for an RGB image.
    weight = mask.astype(np.float32)[:, :, np.newaxis]
    blended = bimg_np * weight + (1.0 - weight) * backimg_solid_np
    return Image.fromarray(np.uint8(np.clip(np.rint(blended), 0, 255)))
def show_segmentation(image, segmentation, flag):
    """Toggle between the plain image and the image with the segmentation overlay.

    Args:
        image: PIL image currently loaded.
        segmentation: PIL segmentation image to overlay.
        flag: ``False`` when the overlay is currently hidden.

    Returns:
        ``(image_to_show, new_flag)``: the overlaid image and ``True`` when
        ``flag`` was ``False``; otherwise the plain image and ``False``.
    """
    if flag is False:
        # PIL reports size as (width, height) while numpy image arrays are
        # (height, width); build the full-image mask in numpy order.
        width, height = image.size
        mask_np = np.ones([height, width]).astype(np.uint8)
        image_edit = transparent_paste_with_mask(image, segmentation, mask_np, transparency = TRANSPARENCY)
        return image_edit, True
    return image, False
def edit_mask_add(canvas, image, idx, mask_np_list):
    """Merge the region drawn on the canvas into mask ``idx``.

    The drawn region is read from the first channel of ``canvas["mask"]``
    (divided by 255 and cast to uint8), united with the selected mask, and
    the whole mask list is then passed through
    ``process_mask_to_follow_priority`` with priority 1 for mask ``idx`` and
    0 for all the others.

    Returns:
        ``(mask_np_list_updated, image_edit)``: the processed mask list and
        ``image`` with the new segmentation overlaid on the whole picture.
    """
    selected = mask_np_list[idx]
    drawn = np.uint8(canvas["mask"][:, :, 0] / 255.)

    merged = [
        mask_union(selected, drawn) if position == idx else current
        for position, current in enumerate(mask_np_list)
    ]

    priorities = [0] * len(merged)
    priorities[idx] = 1
    merged = process_mask_to_follow_priority(merged, priorities)

    whole_image = np.ones([selected.shape[0], selected.shape[1]]).astype(np.uint8)
    overlay = create_segmentation(merged)
    image_edit = transparent_paste_with_mask(image, overlay, whole_image, transparency = TRANSPARENCY)
    return merged, image_edit
def slider_release(index, image, mask_np_list_updated, mask_label_list):
    """Highlight the mask selected by the slider and report its label.

    Args:
        index: mask index chosen on the slider (converted with ``int``).
        image: PIL image currently loaded.
        mask_np_list_updated: list of masks being edited.
        mask_label_list: labels matching ``mask_np_list_updated``.

    Returns:
        ``(new_image, mask_label)`` with the segmentation overlaid only
        inside the selected mask, or ``(image, "out of range")`` when
        ``index`` does not address an existing mask.
    """
    index = int(index)
    # Valid indices are 0 .. len-1; `index == len` must be rejected as well.
    if index < 0 or index >= len(mask_np_list_updated):
        return image, "out of range"

    mask_np = mask_np_list_updated[index]
    mask_label = mask_label_list[index]
    segmentation = create_segmentation(mask_np_list_updated)
    new_image = transparent_paste_with_mask(image, segmentation, mask_np, transparency = TRANSPARENCY)
    return new_image, mask_label
def save_as_orig_mask(mask_np_list_updated, mask_label_list, input_folder="example_tmp"):
    """Validate the masks and save them as the original masks of ``input_folder``.

    Each mask is written to ``mask<i>_<label>.npy`` and a visualisation of
    all masks is written to ``seg_current.png`` via
    ``visualize_mask_list_clean``.

    Raises:
        ValueError: if the list is empty or the masks do not sum to exactly
            1 at every pixel.  Nothing is written in that case.
    """
    # Explicit validation instead of assert + debugger: this runs inside the
    # web server, where an interactive pdb prompt cannot be answered.
    if len(mask_np_list_updated) == 0 or not np.all(sum(mask_np_list_updated) == 1):
        print("please check mask")
        raise ValueError("masks must be non-empty and sum to exactly 1 at every pixel")

    for midx, (mask, mask_label) in enumerate(zip(mask_np_list_updated, mask_label_list)):
        np.save(os.path.join(input_folder, "mask{}_{}.npy".format(midx, mask_label)), mask)
    savepath = os.path.join(input_folder, "seg_current.png")
    visualize_mask_list_clean(mask_np_list_updated, savepath)
def save_as_edit_mask(mask_np_list_updated, mask_label_list, input_folder="example_tmp"):
    """Validate the masks and save them as the edited masks of ``input_folder``.

    Each mask is written to ``maskEdited<i>_<label>.npy`` and a visualisation
    of all masks is written to ``seg_edited.png`` via
    ``visualize_mask_list_clean``.

    Raises:
        ValueError: if the list is empty or the masks do not sum to exactly
            1 at every pixel.  Nothing is written in that case.
    """
    # Explicit validation instead of assert + debugger: this runs inside the
    # web server, where an interactive pdb prompt cannot be answered.
    if len(mask_np_list_updated) == 0 or not np.all(sum(mask_np_list_updated) == 1):
        print("please check mask")
        raise ValueError("masks must be non-empty and sum to exactly 1 at every pixel")

    for midx, (mask, mask_label) in enumerate(zip(mask_np_list_updated, mask_label_list)):
        np.save(os.path.join(input_folder, "maskEdited{}_{}.npy".format(midx, mask_label)), mask)
    savepath = os.path.join(input_folder, "seg_edited.png")
    visualize_mask_list_clean(mask_np_list_updated, savepath)
import shutil
# Delete any ./example_tmp directory already on disk. This runs at import
# time, before the UI below is built.
if os.path.isdir("./example_tmp"):
    shutil.rmtree("./example_tmp")
from segment import run_segmentation
# Gradio UI: tab 1 edits the masks, tab 2 runs the optimisation, tab 3 runs
# text-based editing.  Handlers receive the *current* component values through
# the `inputs` list of each event listener.
with gr.Blocks() as demo:
    # Per-session state shared between the event handlers below.
    image = gr.State()  # not referenced by any handler in this file
    image_loaded = gr.State()  # image returned by load_image_ui
    segmentation = gr.State()  # segmentation image returned by load_image_ui
    mask_np_list = gr.State([])
    mask_label_list = gr.State([])
    mask_np_list_updated = gr.State([])
    # Constant states used to pass a fixed boolean to load_image_ui.
    true = gr.State(True)
    false = gr.State(False)

    with gr.Row():
        gr.Markdown("""# D-Edit""")

    with gr.Tab(label="1 Edit mask"):
        with gr.Row():
            with gr.Column():
                canvas = gr.Image(value = "./img.png", type="numpy", label="Draw Mask", show_label=True, height=LENGTH, width=LENGTH, interactive=True)

                segment_button = gr.Button("1.1 Run segmentation")
                segment_button.click(run_segmentation,
                    [canvas],
                    [])

                text_button = gr.Button("1.2 Load original masks")
                text_button.click(load_image_ui,
                    [false],
                    [image_loaded, segmentation, mask_np_list, mask_label_list, canvas])

                load_edit_button = gr.Button("1.2 Load edited masks")
                load_edit_button.click(load_image_ui,
                    [true],
                    [image_loaded, segmentation, mask_np_list, mask_label_list, canvas])

                show_segment = gr.Checkbox(label = "Show Segmentation")
                flag = gr.State(False)  # whether the overlay is currently shown
                show_segment.select(show_segmentation,
                    [image_loaded, segmentation, flag],
                    [canvas, flag])

                # From here on the "updated" list is the very same State as
                # the loaded list, so edits act directly on the loaded masks.
                mask_np_list_updated = mask_np_list

            with gr.Column():
                gr.Markdown("""<p style="text-align: center; font-size: 20px">Draw Mask</p>""")
                slider = gr.Slider(0, 20, step=1, interactive=True)
                label = gr.Textbox()
                slider.release(slider_release,
                    inputs = [slider, image_loaded, mask_np_list_updated, mask_label_list],
                    outputs= [canvas, label]
                )

                add_button = gr.Button("Add")
                add_button.click(edit_mask_add,
                    [canvas, image_loaded, slider, mask_np_list_updated],
                    [mask_np_list_updated, canvas]
                )

                save_button2 = gr.Button("Set and Save as edited masks")
                save_button2.click(save_as_edit_mask,
                    [mask_np_list_updated, mask_label_list],
                    [])

                save_button = gr.Button("Set and Save as original masks")
                save_button.click(save_as_orig_mask,
                    [mask_np_list_updated, mask_label_list],
                    [])

                back_button = gr.Button("Back to current seg")
                back_button.click(load_mask_ui,
                    [],
                    [mask_np_list_updated, mask_label_list])

                add_mask_button = gr.Button("Add new empty mask")
                add_mask_button.click(add_mask,
                    [mask_np_list_updated, mask_label_list],
                    [mask_np_list_updated, mask_label_list])

    with gr.Tab(label="2 Optimization"):
        with gr.Row():
            with gr.Column():
                gr.Markdown("""<p style="text-align: center; font-size: 20px">Optimization settings (SD)</p>""")
                num_tokens = gr.Number(value="5", label="num tokens to represent each object", interactive= True)
                embedding_learning_rate = gr.Textbox(value="0.0001", label="Embedding optimization: Learning rate", interactive= True )
                max_emb_train_steps = gr.Number(value="200", label="embedding optimization: Training steps", interactive= True )
                diffusion_model_learning_rate = gr.Textbox(value="0.00005", label="UNet Optimization: Learning rate", interactive= True )
                max_diffusion_train_steps = gr.Number(value="200", label="UNet Optimization: Training steps", interactive= True )
                train_batch_size = gr.Number(value="5", label="Batch size", interactive= True )
                gradient_accumulation_steps=gr.Number(value="5", label="Gradient accumulation", interactive= True )

                add_button = gr.Button("Run optimization")

                def run_optimization_wrapper (
                    num_tokens,
                    embedding_learning_rate ,
                    max_emb_train_steps ,
                    diffusion_model_learning_rate ,
                    max_diffusion_train_steps,
                    train_batch_size,
                    gradient_accumulation_steps
                ):
                    """Convert the UI values to numbers and run the optimisation via run_main."""
                    run_main(
                        num_tokens=int(num_tokens),
                        embedding_learning_rate = float(embedding_learning_rate),
                        max_emb_train_steps = int(max_emb_train_steps),
                        diffusion_model_learning_rate= float(diffusion_model_learning_rate),
                        max_diffusion_train_steps = int(max_diffusion_train_steps),
                        train_batch_size=int(train_batch_size),
                        gradient_accumulation_steps=int(gradient_accumulation_steps)
                    )

                add_button.click(run_optimization_wrapper,
                    inputs = [
                        num_tokens,
                        embedding_learning_rate ,
                        max_emb_train_steps ,
                        diffusion_model_learning_rate ,
                        max_diffusion_train_steps,
                        train_batch_size,
                        gradient_accumulation_steps
                    ],
                    outputs = []
                )

    with gr.Tab(label="3 Editing"):
        with gr.Tab(label="3.1 Text-based editing"):
            with gr.Row():
                with gr.Column():
                    canvas_text_edit = gr.Image(value = None, type = "pil", label="Editing results", show_label=True)

                with gr.Column():
                    gr.Markdown("""<p style="text-align: center; font-size: 20px">Editing setting (SD)</p>""")
                    tgt_prompt = gr.Textbox(value="White bag", label="Editing: Text prompt", interactive= True )
                    tgt_index = gr.Number(value="0", label="Editing: Object index", interactive= True )
                    guidance_scale = gr.Textbox(value="6", label="Editing: CFG guidance scale", interactive= True )
                    num_sampling_steps = gr.Number(value="50", label="Editing: Sampling steps", interactive= True )
                    edge_thickness = gr.Number(value="10", label="Editing: Edge thickness", interactive= True )
                    strength = gr.Textbox(value="0.5", label="Editing: Mask strength", interactive= True )

                    add_button = gr.Button("Run Editing")

                    def run_edit_text_wrapper(
                        num_tokens,
                        guidance_scale,
                        num_sampling_steps,
                        strength,
                        edge_thickness,
                        tgt_prompt,
                        tgt_index
                    ):
                        """Run text-based editing via run_main with the values currently in the UI.

                        The values arrive through the listener's `inputs`, so
                        what the user typed is used rather than the defaults
                        the components were created with.
                        """
                        return run_main(
                            load_trained=True,
                            text=True,
                            num_tokens = int(num_tokens),
                            guidance_scale = float(guidance_scale),
                            num_sampling_steps = int(num_sampling_steps),
                            strength = float(strength),
                            edge_thickness = int(edge_thickness),
                            num_imgs = 1,
                            tgt_prompt = tgt_prompt,
                            tgt_index = int(tgt_index)
                        )

                    add_button.click(run_edit_text_wrapper,
                        inputs = [
                            num_tokens,
                            guidance_scale,
                            num_sampling_steps,
                            strength,
                            edge_thickness,
                            tgt_prompt,
                            tgt_index
                        ],
                        outputs = [canvas_text_edit]
                    )

                    def load_pil_img():
                        """Open the first text-editing result saved under example_tmp."""
                        return Image.open("example_tmp/text/out_text_0.png")

                    load_button = gr.Button("Load results")
                    load_button.click(load_pil_img,
                        inputs = [],
                        outputs = [canvas_text_edit]
                    )

demo.queue().launch(share=True, debug=True)