import os
import PIL
import cv2
import time
import torch
import spaces
import numpy as np
import gradio as gr
from PIL import Image
from torch import autocast
from contextlib import nullcontext
from itertools import islice
from omegaconf import OmegaConf
from einops import rearrange, repeat
from pytorch_lightning import seed_everything
from ldm.util import instantiate_from_config
from ldm.models.diffusion.dpm_solver import DPMSolverSampler
from gradio_image_annotation import image_annotator

DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
CONFIG_PATH = "./configs/stable-diffusion/v2-inference.yaml"
CKPT_PATH = "./ckpt/v2-1_512-ema-pruned.ckpt"

if not os.path.exists(CKPT_PATH):
    # automatically download the checkpoint if it doesn't exist
    print(f"Checkpoint {CKPT_PATH} not found, downloading from huggingface")
    os.system(f"wget -O {CKPT_PATH} https://huggingface.co/stabilityai/stable-diffusion-2-1-base/resolve/main/v2-1_512-ema-pruned.ckpt")

CONFIG = OmegaConf.load(CONFIG_PATH)


def load_img(image, SCALE, pad=False, seg_map=False, target_size=None):
    if seg_map:
        # Load the input image and segmentation map
        # image = Image.open(path).convert("RGB")
        # seg_map = Image.open(seg).convert("1")
        seg_map = seg_map.convert("1")

        # Get the width and height of the original image
        w, h = image.size

        # Calculate the aspect ratio of the original image
        aspect_ratio = h / w

        # Determine the new dimensions for resizing the image while maintaining aspect ratio
        if aspect_ratio > 1:
            new_w = int(SCALE * 256 / aspect_ratio)
            new_h = int(SCALE * 256)
        else:
            new_w = int(SCALE * 256)
            new_h = int(SCALE * 256 * aspect_ratio)

        # Resize the image and the segmentation map to the new dimensions
        image_resize = image.resize((new_w, new_h))
        segmentation_map_resize = cv2.resize(np.array(seg_map).astype(np.uint8), (new_w, new_h), interpolation=cv2.INTER_NEAREST)

        # Pad the segmentation map to match the target size
        padded_segmentation_map = np.zeros((target_size[1], target_size[0]))
        start_x = (target_size[1] - segmentation_map_resize.shape[0]) // 2
        start_y = (target_size[0] - segmentation_map_resize.shape[1]) // 2
        padded_segmentation_map[start_x : start_x + segmentation_map_resize.shape[0], start_y : start_y + segmentation_map_resize.shape[1]] = (
            segmentation_map_resize
        )

        # Create a new RGB image with the target size and place the resized image in the center
        padded_image = Image.new("RGB", target_size)
        start_x = (target_size[0] - image_resize.width) // 2
        start_y = (target_size[1] - image_resize.height) // 2
        padded_image.paste(image_resize, (start_x, start_y))

        # Update the variable "image" to contain the final padded image
        image = padded_image
    else:
        # image = Image.open(path).convert("RGB")
        w, h = image.size
        # print(f"loaded input image of size ({w}, {h}) from {path}")
        w, h = map(lambda x: x - x % 64, (w, h))  # resize to integer multiple of 64
        w = h = 512
        image = image.resize((w, h), resample=PIL.Image.LANCZOS)

    image = np.array(image).astype(np.float32) / 255.0
    image = image[None].transpose(0, 3, 1, 2)
    image = torch.from_numpy(image)

    if pad or seg_map:
        return 2.0 * image - 1.0, new_w, new_h, padded_segmentation_map

    return 2.0 * image - 1.0, w, h


def load_model_and_get_prompt_embedding(model, scale, device, prompts, inv=False):
    if inv:
        inv_emb = model.get_learned_conditioning(prompts, inv)
        c = uc = inv_emb
    else:
        inv_emb = None

    if scale != 1.0:
        uc = model.get_learned_conditioning([""])
    else:
        uc = None

    c = model.get_learned_conditioning(prompts)

    return c, uc, inv_emb


def chunk(it, size):
    it = iter(it)
    return iter(lambda: tuple(islice(it, size)), ())
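
# Illustrative sketch (not called by the app): for a background image, `load_img`
# resizes to 512x512 and returns a tensor in [-1, 1] of shape (1, 3, 512, 512)
# plus the working width/height. The helper name `_demo_load_img` and the print
# are assumptions added for demonstration only.
def _demo_load_img(pil_image):
    tensor, w, h = load_img(pil_image, SCALE=1.0)
    print(tensor.shape, float(tensor.min()), float(tensor.max()), (w, h))
    return tensor
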
def load_model_from_config(config, ckpt, gpu, verbose=False):
    print(f"Loading model from {ckpt}")
    pl_sd = torch.load(ckpt, map_location=gpu)
    if "global_step" in pl_sd:
        print(f"Global Step: {pl_sd['global_step']}")
    sd = pl_sd["state_dict"]
    model = instantiate_from_config(config.model)
    m, u = model.load_state_dict(sd, strict=False)
    if len(m) > 0 and verbose:
        print("missing keys:")
        print(m)
    if len(u) > 0 and verbose:
        print("unexpected keys:")
        print(u)

    model.eval()
    return model


MODEL = load_model_from_config(CONFIG, CKPT_PATH, DEVICE)
MODEL.to(device=DEVICE)
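
# Optional sanity check (illustrative; the app never calls it): report how many
# parameters the restored checkpoint contains. `_print_model_size` is an assumed
# helper name, not part of the original demo.
def _print_model_size(model=MODEL):
    n_params = sum(p.numel() for p in model.parameters())
    print(f"Loaded LDM with {n_params / 1e6:.1f}M parameters on {DEVICE}")
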
@spaces.GPU(duration=60)
def tficon(img_with_mask, ref_img, seg, prompt, dpm_order, dpm_steps, tau_a, tau_b, domain, seed, scale):
    init_img = img_with_mask["image"]
    n_samples = 1
    precision = "autocast"
    ddim_eta = 0.0
    dpm_order = int(dpm_order[0])
    device = DEVICE
    model = MODEL
    batch_size = n_samples

    sampler = DPMSolverSampler(model)
    seed_everything(seed)

    # img = cv2.imread(mask, 0)
    # # Threshold the image to create binary image
    # _, binary = cv2.threshold(img, 127, 255, cv2.THRESH_BINARY)
    # # Find the contours of the white region in the image
    # contours, _ = cv2.findContours(binary, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
    # # Find the bounding rectangle of the largest contour
    # x, y, new_w, new_h = cv2.boundingRect(contours[0])

    # Calculate the center of the rectangle
    bbox = img_with_mask["boxes"][0]
    x = bbox["xmin"]
    y = bbox["ymin"]
    new_w = bbox["xmax"] - bbox["xmin"]
    new_h = bbox["ymax"] - bbox["ymin"]
    center_x = x + new_w / 2
    center_y = y + new_h / 2

    # Calculate the percentage from the top and left
    center_row_from_top = round(center_y / 512, 2)
    center_col_from_left = round(center_x / 512, 2)

    aspect_ratio = new_h / new_w
    if aspect_ratio > 1:
        mask_scale = new_w * aspect_ratio / 256
        mask_scale = new_h / 256
    else:
        mask_scale = new_w / 256
        mask_scale = new_h / (aspect_ratio * 256)
    # both branches reduce to (longer box side) / 256
    # mask_scale = round(mask_scale, 2)

    # =============================================================================================

    data = [batch_size * [prompt]]

    # read background image
    init_image, target_width, target_height = load_img(init_img, mask_scale)
    init_image = repeat(init_image.to(device), "1 ... -> b ...", b=batch_size)
    save_image = init_image.clone()

    # read foreground image and its segmentation map
    ref_image, width, height, segmentation_map = load_img(ref_img, mask_scale, seg_map=seg, target_size=(target_width, target_height))
    ref_image = repeat(ref_image.to(device), "1 ... -> b ...", b=batch_size)

    segmentation_map_orig = repeat(torch.tensor(segmentation_map)[None, None, ...].to(device), "1 1 ... -> b 4 ...", b=batch_size)
    segmentation_map_save = repeat(torch.tensor(segmentation_map)[None, None, ...].to(device), "1 1 ... -> b 3 ...", b=batch_size)
    segmentation_map = segmentation_map_orig[:, :, ::8, ::8].to(device)

    top_rr = int((0.5 * (target_height - height)) / target_height * init_image.shape[2])  # xx% from the top
    bottom_rr = int((0.5 * (target_height + height)) / target_height * init_image.shape[2])
    left_rr = int((0.5 * (target_width - width)) / target_width * init_image.shape[3])  # xx% from the left
    right_rr = int((0.5 * (target_width + width)) / target_width * init_image.shape[3])

    center_row_rm = int(center_row_from_top * target_height)
    center_col_rm = int(center_col_from_left * target_width)

    step_height2, remainder = divmod(height, 2)
    step_height1 = step_height2 + remainder
    step_width2, remainder = divmod(width, 2)
    step_width1 = step_width2 + remainder

    # compositing in pixel space for same-domain composition
    save_image[:, :, center_row_rm - step_height1 : center_row_rm + step_height2, center_col_rm - step_width1 : center_col_rm + step_width2] = (
        save_image[
            :, :, center_row_rm - step_height1 : center_row_rm + step_height2, center_col_rm - step_width1 : center_col_rm + step_width2
        ].clone()
        * (1 - segmentation_map_save[:, :, top_rr:bottom_rr, left_rr:right_rr])
        + ref_image[:, :, top_rr:bottom_rr, left_rr:right_rr].clone() * segmentation_map_save[:, :, top_rr:bottom_rr, left_rr:right_rr]
    )

    # save the mask and the pixel space composited image
    save_mask = torch.zeros_like(init_image)
    save_mask[:, :, center_row_rm - step_height1 : center_row_rm + step_height2, center_col_rm - step_width1 : center_col_rm + step_width2] = 1

    # image = Image.fromarray(((save_image/torch.max(save_image.max(), abs(save_image.min())) + 1) * 127.5)[0].permute(1,2,0).to(dtype=torch.uint8).cpu().numpy())

    precision_scope = autocast if precision == "autocast" else nullcontext

    # image composition
    with torch.no_grad():
        with precision_scope("cuda"):
            for prompts in data:
                print(prompts)
                c, uc, inv_emb = load_model_and_get_prompt_embedding(model, scale, device, prompts, inv=True)

                if domain == "Real Domain":  # same domain
                    init_image = save_image

                T1 = time.time()

                init_latent = model.get_first_stage_encoding(model.encode_first_stage(init_image))

                # ref's location in ref image in the latent space
                top_rr = int((0.5 * (target_height - height)) / target_height * init_latent.shape[2])
                bottom_rr = int((0.5 * (target_height + height)) / target_height * init_latent.shape[2])
                left_rr = int((0.5 * (target_width - width)) / target_width * init_latent.shape[3])
                right_rr = int((0.5 * (target_width + width)) / target_width * init_latent.shape[3])

                new_height = bottom_rr - top_rr
                new_width = right_rr - left_rr

                step_height2, remainder = divmod(new_height, 2)
                step_height1 = step_height2 + remainder
                step_width2, remainder = divmod(new_width, 2)
                step_width1 = step_width2 + remainder

                center_row_rm = int(center_row_from_top * init_latent.shape[2])
                center_col_rm = int(center_col_from_left * init_latent.shape[3])

                param = [
                    max(0, int(center_row_rm - step_height1)),
                    min(init_latent.shape[2] - 1, int(center_row_rm + step_height2)),
                    max(0, int(center_col_rm - step_width1)),
                    min(init_latent.shape[3] - 1, int(center_col_rm + step_width2)),
                ]

                ref_latent = model.get_first_stage_encoding(model.encode_first_stage(ref_image))

                shape = [init_latent.shape[1], init_latent.shape[2], init_latent.shape[3]]

                z_enc, _ = sampler.sample(
                    steps=dpm_steps,
                    inv_emb=inv_emb,
                    unconditional_conditioning=uc,
                    conditioning=c,
                    batch_size=n_samples,
                    shape=shape,
                    verbose=False,
                    unconditional_guidance_scale=scale,
                    eta=ddim_eta,
                    order=dpm_order,
                    x_T=init_latent,
                    width=width,
                    height=height,
                    DPMencode=True,
                )

                z_ref_enc, _ = sampler.sample(
                    steps=dpm_steps,
                    inv_emb=inv_emb,
                    unconditional_conditioning=uc,
                    conditioning=c,
                    batch_size=n_samples,
                    shape=shape,
                    verbose=False,
                    unconditional_guidance_scale=scale,
                    eta=ddim_eta,
                    order=dpm_order,
                    x_T=ref_latent,
                    DPMencode=True,
                    width=width,
                    height=height,
                    ref=True,
                )

                samples_orig = z_enc.clone()

                # inpainting in XOR region of M_seg and M_mask
                z_enc[:, :, param[0] : param[1], param[2] : param[3]] = (
                    z_enc[:, :, param[0] : param[1], param[2] : param[3]]
                    * segmentation_map[:, :, top_rr:bottom_rr, left_rr:right_rr]
                    + torch.randn((1, 4, bottom_rr - top_rr, right_rr - left_rr), device=device)
                    * (1 - segmentation_map[:, :, top_rr:bottom_rr, left_rr:right_rr])
                )

                samples_for_cross = samples_orig.clone()
                samples_ref = z_ref_enc.clone()
                samples = z_enc.clone()

                # noise composition
                if domain == "Cross Domain":
                    samples[:, :, param[0] : param[1], param[2] : param[3]] = torch.randn(
                        (1, 4, bottom_rr - top_rr, right_rr - left_rr), device=device
                    )

                # apply the segmentation mask on the noise
                samples[:, :, param[0] : param[1], param[2] : param[3]] = (
                    samples[:, :, param[0] : param[1], param[2] : param[3]].clone()
                    * (1 - segmentation_map[:, :, top_rr:bottom_rr, left_rr:right_rr])
                    + z_ref_enc[:, :, top_rr:bottom_rr, left_rr:right_rr].clone()
                    * segmentation_map[:, :, top_rr:bottom_rr, left_rr:right_rr]
                )

                mask = torch.zeros_like(z_enc, device=device)
                mask[:, :, param[0] : param[1], param[2] : param[3]] = 1

                samples, _ = sampler.sample(
                    steps=dpm_steps,
                    inv_emb=inv_emb,
                    conditioning=c,
                    batch_size=n_samples,
                    shape=shape,
                    verbose=False,
                    unconditional_guidance_scale=scale,
                    unconditional_conditioning=uc,
                    eta=ddim_eta,
                    order=dpm_order,
                    x_T=[samples_orig, samples.clone(), samples_for_cross, samples_ref, samples, init_latent],
                    width=width,
                    height=height,
                    segmentation_map=segmentation_map,
                    param=param,
                    mask=mask,
                    target_height=target_height,
                    target_width=target_width,
                    center_row_rm=center_row_from_top,
                    center_col_rm=center_col_from_left,
                    tau_a=tau_a,
                    tau_b=tau_b,
                )

                x_samples = model.decode_first_stage(samples)
                x_samples = torch.clamp((x_samples + 1.0) / 2.0, min=0.0, max=1.0)

                T2 = time.time()
                print("Running Time: %s s" % (T2 - T1))

                for x_sample in x_samples:
                    x_sample = 255.0 * rearrange(x_sample.cpu().numpy(), "c h w -> h w c")
                    img = Image.fromarray(x_sample.astype(np.uint8))
                    # img.save(os.path.join(sample_path, f"{base_count:05}_{prompts[0]}.png"))

    return img
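
# Geometry sketch (illustrative; mirrors the bounding-box handling at the top of
# `tficon`): the box drawn on the 512x512 background becomes a relative center
# position plus a scale of the box's longer side in units of 256 px. The helper
# name `_bbox_to_placement` is an assumption for demonstration only.
def _bbox_to_placement(bbox):
    box_w = bbox["xmax"] - bbox["xmin"]
    box_h = bbox["ymax"] - bbox["ymin"]
    center_row_from_top = round((bbox["ymin"] + box_h / 2) / 512, 2)
    center_col_from_left = round((bbox["xmin"] + box_w / 2) / 512, 2)
    mask_scale = max(box_w, box_h) / 256
    return center_row_from_top, center_col_from_left, mask_scale


# Example: a 256x128 box with its top-left corner at (128, 192) is centered at
# (0.5, 0.5) of the canvas and gives mask_scale = 1.0.
# _bbox_to_placement({"xmin": 128, "ymin": 192, "xmax": 384, "ymax": 320})
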
def read_content(file_path: str) -> str:
    """read the content of target file"""
    with open(file_path, "r", encoding="utf-8") as f:
        content = f.read()

    return content


example = {}
ref_dir = "./gradio/foreground"
image_dir = "./gradio/background"
seg_dir = "./gradio/seg_foreground"

image_list = [os.path.join(image_dir, file) for file in os.listdir(image_dir)]
image_list.sort()
ref_list = [os.path.join(ref_dir, file) for file in os.listdir(ref_dir)]
ref_list.sort()
seg_list = [os.path.join(seg_dir, file) for file in os.listdir(seg_dir)]
seg_list.sort()

reference_list = [[ref_img, ref_mask] for ref_img, ref_mask in zip(ref_list, seg_list)]

image_list = [
    {
        "image": image,
        "boxes": [
            {
                "xmin": 128,
                "ymin": 128,
                "xmax": 384,
                "ymax": 384,
                "label": "Mask",
                "color": (250, 0, 0),
            }
        ],
    }
    for image in image_list
]


def update_mask(image):
    print("update mask")
    bbox = image["boxes"][0]
    label = image["boxes"][0]["label"]
    xmin = bbox["xmin"]
    ymin = bbox["ymin"]
    xmax = bbox["xmax"]
    ymax = bbox["ymax"]
    coords = [xmin, ymin, xmax, ymax]
    return (image["image"], [(coords, label)])
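
# Expected input/output for update_mask (format assumed from the image_annotator
# payload built above): a dict with the PIL image and one labelled box is mapped
# to the (image, annotations) pair that gr.AnnotatedImage expects, e.g.
#   {"image": <PIL.Image>, "boxes": [{"xmin": 128, "ymin": 128, "xmax": 384,
#                                     "ymax": 384, "label": "Mask", "color": (250, 0, 0)}]}
#   -> (<PIL.Image>, [([128, 128, 384, 384], "Mask")])
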
if __name__ == "__main__":
    with gr.Blocks() as demo:
        gr.HTML(
            """

            <h1>🦄 TF-ICON: Diffusion-Based Training-Free Cross-Domain Image Composition</h1>


            <p>This is an unofficial demo for the paper 'TF-ICON: Diffusion-Based Training-Free Cross-Domain Image Composition'.</p>

""" ) with gr.Row(): with gr.Column(): # back_image_invisible = gr.Image(elem_id="image_upload", type="pil", label="Background Image", height=512, visible=False) image = image_annotator( None, label_list=["Mask"], label_colors=[(255, 0, 0)], height=512, image_type="pil", ) # back_image_invisible.change(fn=set_image, inputs=[back_image_invisible, image]) mask_btn = gr.Button("Generate Mask") reference = gr.Image(elem_id="image_upload", type="pil", label="Foreground Image", height=512) with gr.Row(): # guidance = gr.Slider(label="Guidance scale", value=5, maximum=15,interactive=True) steps = gr.Slider(label="Steps", value=50, minimum=2, maximum=75, step=1, interactive=True) seed = gr.Slider(0, 10000, label="Seed (0 = random)", value=3407, step=1) with gr.Row(): tau_a = gr.Slider( label="tau_a", value=0.4, minimum=0.0, maximum=1.0, step=0.1, interactive=True, info="Foreground Attention Injection", ) tau_b = gr.Slider( label="tau_b", value=0.8, minimum=0.0, maximum=1.0, step=0.1, interactive=True, info="Background Preservation" ) with gr.Row(): scale = gr.Slider( label="CFG", value=2.5, minimum=0.0, maximum=15.0, step=0.5, interactive=True, info="CFG=2.5 for real domain CFG>=5.0 for cross domain", ) dpm_order = gr.CheckboxGroup(["1", "2", "3"], value="2", label="DPM Solver Order") domain = gr.Radio( ["Cross Domain", "Real Domain"], value="Real Domain", label="Domain", info="When background is real image, choose Real Domain; otherwise, choose Cross Domain", ) prompt = gr.Textbox(label="Prompt", info="an oil painting (or a pencil drawing) of a panda") # .style(height=512) btn = gr.Button("Run!") # with gr.Column(): mask = gr.AnnotatedImage( label="Composition Region", # info="Setting mask for composition region: first click for the top left corner, second click for the bottom right corner", color_map={"Region for Composing Object": "#9987FF", "Click Second Point for Mask": "#f44336"}, height=512, ) mask_btn.click(fn=update_mask, inputs=[image], outputs=[mask]) # image.select(get_select_coordinates, image, mask) seg = gr.Image(elem_id="image_upload", type="pil", label="Segmentation Mask for Foreground", height=512) image_out = gr.Image(label="Output", elem_id="output-img", height=512) # with gr.Group(elem_id="share-btn-container"): # community_icon = gr.HTML(community_icon_html, visible=True) # loading_icon = gr.HTML(loading_icon_html, visible=True) # share_button = gr.Button("Share to community", elem_id="share-btn", visible=True) with gr.Row(): with gr.Column(): gr.Examples(image_list, inputs=[image], label="Examples - Background Image", examples_per_page=12) with gr.Column(): gr.Examples(reference_list, inputs=[reference, seg], label="Examples - Foreground Image", examples_per_page=12) btn.click(fn=tficon, inputs=[image, reference, seg, prompt, dpm_order, steps, tau_a, tau_b, domain, seed, scale], outputs=[image_out]) demo.queue(max_size=10).launch()