import os
import cv2
import numpy as np
import gradio as gr
from copy import deepcopy
from einops import rearrange
from types import SimpleNamespace

import datetime
import PIL
from PIL import Image
from PIL.ImageOps import exif_transpose
import torch
import torch.nn.functional as F

from diffusers import DDIMScheduler, AutoencoderKL, DPMSolverMultistepScheduler
from drag_pipeline import DragPipeline

from torchvision.utils import save_image
from pytorch_lightning import seed_everything

from .drag_utils import drag_diffusion_update, drag_diffusion_update_gen
from .lora_utils import train_lora
from .attn_utils import register_attention_editor_diffusers, MutualSelfAttentionControl

import imageio


# -------------- shared private helpers --------------

def _get_device():
    """Return the CUDA device when available, otherwise the CPU device."""
    return torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")


def _ddim_scheduler():
    """Build the DDIM scheduler configuration used by every entry point in this module."""
    return DDIMScheduler(beta_start=0.00085, beta_end=0.012,
                         beta_schedule="scaled_linear", clip_sample=False,
                         set_alpha_to_one=False, steps_offset=1)


def _build_scheduler(scheduler_name, model):
    """Create the scheduler selected in the UI.

    The DPM++ variants are derived from the model's current scheduler config,
    so this must be called before `model.scheduler` is replaced.

    Raises:
        NotImplementedError: for an unknown `scheduler_name`.
    """
    if scheduler_name == "DDIM":
        return _ddim_scheduler()
    if scheduler_name == "DPM++2M":
        return DPMSolverMultistepScheduler.from_config(model.scheduler.config)
    if scheduler_name == "DPM++2M_karras":
        return DPMSolverMultistepScheduler.from_config(
            model.scheduler.config, use_karras_sigmas=True)
    raise NotImplementedError("scheduler name not correct")


def _maybe_load_vae(model, vae_path):
    """Replace the pipeline VAE unless `vae_path` is the literal "default"."""
    if vae_path != "default":
        model.vae = AutoencoderKL.from_pretrained(
            vae_path
        ).to(model.vae.device, model.vae.dtype)


def _maybe_load_lora_weights(model, lora_path):
    """Load `lora.safetensors` from `lora_path` into the pipeline when a path is given."""
    if lora_path != "":
        print("applying lora: " + lora_path)
        model.load_lora_weights(lora_path, weight_name="lora.safetensors")


def _prepare_mask(mask, size, device):
    """Turn a numpy (H, W) mask into a binary (1, 1, *size) float tensor on `device`."""
    mask = torch.from_numpy(mask).float() / 255.
    mask[mask > 0.0] = 1.0
    mask = rearrange(mask, "h w -> 1 1 h w").to(device)
    return F.interpolate(mask, size, mode="nearest")


def _register_editor(model, lora_path, start_step, start_layer, args):
    """Hijack the attention modules so the reference branch guides generation."""
    editor = MutualSelfAttentionControl(start_step=start_step,
                                        start_layer=start_layer,
                                        total_steps=args.n_inference_step,
                                        guidance_scale=args.guidance_scale)
    attn_processor = 'attn_proc' if lora_path == "" else 'lora_attn_proc'
    register_attention_editor_diffusers(model, editor, attn_processor=attn_processor)


def _to_uint8_image(image):
    """Convert the first image of a (B, C, H, W) tensor into an (H, W, C) uint8 array.

    Values are multiplied by 255 before the cast, i.e. the tensor is presumably in [0, 1].
    """
    out = image.cpu().permute(0, 2, 3, 1).numpy()[0]
    return (out * 255).astype(np.uint8)


def _timestamped_path(save_dir, suffix):
    """Ensure `save_dir` exists (including parents) and return a timestamped file path in it."""
    os.makedirs(save_dir, exist_ok=True)
    save_prefix = datetime.datetime.now().strftime("%Y-%m-%d-%H%M-%S")
    return os.path.join(save_dir, save_prefix + suffix)


def _masked_preview(image, mask):
    """Return `(preview, mask)` for display.

    If the mask has any non-zero pixel it is binarized to uint8 {0, 1} and the
    pixels outside it are blended toward black; otherwise the image is copied
    and the mask is returned unchanged.
    """
    if mask.sum() > 0:
        mask = np.uint8(mask > 0)
        return mask_image(image, 1 - mask, color=[0, 0, 0], alpha=0.3), mask
    return image.copy(), mask


# -------------- general UI functionality --------------
def clear_all(length=480):
    """Reset the three image widgets to empty `length` x `length` canvases and clear the remaining outputs."""
    return gr.Image.update(value=None, height=length, width=length), \
        gr.Image.update(value=None, height=length, width=length), \
        gr.Image.update(value=None, height=length, width=length), \
        [], None, None


def clear_all_gen(length=480):
    """Same as `clear_all`, with one extra cleared output for the generation tab."""
    return gr.Image.update(value=None, height=length, width=length), \
        gr.Image.update(value=None, height=length, width=length), \
        gr.Image.update(value=None, height=length, width=length), \
        [], None, None, None


def mask_image(image, mask, color=(255, 0, 0), alpha=0.5):
    """ Overlay mask on image for visualization purpose.
    Args:
        image (H, W, 3) or (H, W): input image
        mask (H, W): mask to be overlaid
        color: the color of overlaid mask (a tuple by default to avoid a shared mutable default)
        alpha: the transparency of the mask
    """
    out = deepcopy(image)
    img = deepcopy(image)
    img[mask == 1] = color
    out = cv2.addWeighted(img, alpha, out, 1 - alpha, 0, out)
    return out


def store_img(img, length=512):
    """Store an uploaded image and mask resized to width `length` (aspect ratio preserved).

    Returns:
        (image, selected_points, masked_preview, mask); `selected_points` is
        reset to [] because a new image was uploaded.
    """
    image, mask = img["image"], np.float32(img["mask"][:, :, 0]) / 255.
    height, width, _ = image.shape
    new_size = (length, int(length * height / width))
    image = exif_transpose(Image.fromarray(image))
    image = np.array(image.resize(new_size, PIL.Image.BILINEAR))
    mask = cv2.resize(mask, new_size, interpolation=cv2.INTER_NEAREST)
    masked_img, mask = _masked_preview(image, mask)
    # when new image is uploaded, `selected_points` should be empty
    return image, [], masked_img, mask


def store_img_gen(img):
    """Like `store_img` but without resizing (used for generated images).

    Once the user uploads/generates an image, the original image is stored in
    `original_image`, and the same image is displayed in `input_image` for
    point-clicking purposes.
    """
    image, mask = img["image"], np.float32(img["mask"][:, :, 0]) / 255.
    image = np.array(exif_transpose(Image.fromarray(image)))
    masked_img, mask = _masked_preview(image, mask)
    # when new image is uploaded, `selected_points` should be empty
    return image, [], masked_img, mask


def get_points(img, sel_pix, evt: gr.SelectData):
    """Record the clicked point and redraw all handle/target points on a copy of `img`.

    Points alternate handle (even index, red) / target (odd index, blue); each
    completed pair is connected by a white arrow. `sel_pix` is mutated in place.
    """
    img_copy = img.copy() if isinstance(img, np.ndarray) else np.array(img)
    # collect the selected point
    sel_pix.append(evt.index)
    points = []
    for idx, point in enumerate(sel_pix):
        if idx % 2 == 0:
            # draw a red circle at the handle point
            cv2.circle(img_copy, tuple(point), 10, (255, 0, 0), -1)
        else:
            # draw a blue circle at the target point
            cv2.circle(img_copy, tuple(point), 10, (0, 0, 255), -1)
        points.append(tuple(point))
        # draw an arrow from handle point to target point
        if len(points) == 2:
            cv2.arrowedLine(img_copy, points[0], points[1], (255, 255, 255), 4, tipLength=0.5)
            points = []
    return img_copy


def undo_points(original_image, mask):
    """Clear all handle/target points: return the masked preview and an empty point list."""
    masked_img, _ = _masked_preview(original_image, mask)
    return masked_img, []
# ------------------------------------------------------


# ----------- dragging user-input image utils -----------
def train_lora_interface(original_image,
                         prompt,
                         model_path,
                         vae_path,
                         lora_path,
                         lora_step,
                         lora_lr,
                         lora_rank,
                         progress=gr.Progress()):
    """Run `train_lora` with the UI arguments and report completion."""
    train_lora(
        original_image,
        prompt,
        model_path,
        vae_path,
        lora_path,
        lora_step,
        lora_lr,
        lora_rank,
        progress)
    return "Training LoRA Done!"


def preprocess_image(image, device):
    """Convert an (H, W, C) image with values in [0, 255] to a (1, C, H, W) tensor in [-1, 1] on `device`."""
    image = torch.from_numpy(image).float() / 127.5 - 1  # [-1, 1]
    image = rearrange(image, "h w c -> 1 c h w")
    image = image.to(device)
    return image


def run_drag(source_image,
             image_with_clicks,
             mask,
             prompt,
             points,
             inversion_strength,
             lam,
             latent_lr,
             n_pix_step,
             model_path,
             vae_path,
             lora_path,
             start_step,
             start_layer,
             create_gif_checkbox,
             gif_interval,
             save_dir="./results"
    ):
    """Drag-edit a user-supplied image.

    The image is inverted with DDIM, the latent is optimized so handle points
    move toward target points, then the edited latent is denoised alongside the
    original (reference) latent. Results are saved to `save_dir`.

    Returns:
        The edited image as an (H, W, 3) uint8 array.
    """
    # initialize model
    device = _get_device()
    model = DragPipeline.from_pretrained(model_path, scheduler=_ddim_scheduler()).to(device)
    # call this function to override unet forward function,
    # so that intermediate features are returned after forward
    model.modify_unet_forward()
    _maybe_load_vae(model, vae_path)

    # initialize parameters
    seed = 42  # random seed used by a lot of people for unknown reason
    seed_everything(seed)

    args = SimpleNamespace()
    args.prompt = prompt
    args.points = points
    args.n_inference_step = 50
    args.n_actual_inference_step = round(inversion_strength * args.n_inference_step)
    args.guidance_scale = 1.0
    args.unet_feature_idx = [3]
    args.sup_res = 256
    args.r_m = 1
    args.r_p = 3
    args.lam = lam
    args.lr = latent_lr
    args.n_pix_step = n_pix_step
    args.create_gif_checkbox = create_gif_checkbox
    args.gif_interval = gif_interval
    print(args)

    full_h, full_w = source_image.shape[:2]
    source_image = preprocess_image(source_image, device)
    image_with_clicks = preprocess_image(image_with_clicks, device)

    # set lora
    if lora_path == "":
        print("applying default parameters")
        model.unet.set_default_attn_processor()
    else:
        print("applying lora: " + lora_path)
        model.unet.load_attn_procs(lora_path)

    # invert the source image
    # the latent code resolution is too small, only 64*64
    invert_code = model.invert(source_image,
                               prompt,
                               guidance_scale=args.guidance_scale,
                               num_inference_steps=args.n_inference_step,
                               num_actual_inference_steps=args.n_actual_inference_step)

    # `.to(device)` instead of `.cuda()` so the CPU fallback chosen above also works
    mask = _prepare_mask(mask, (args.sup_res, args.sup_res), device)

    # points arrive as (x, y); convert to (row, col) on the sup_res grid.
    # Even indices are handle points, odd indices are target points.
    handle_points = []
    target_points = []
    for idx, point in enumerate(points):
        cur_point = torch.tensor([point[1] / full_h, point[0] / full_w]) * args.sup_res
        cur_point = torch.round(cur_point)
        (handle_points if idx % 2 == 0 else target_points).append(cur_point)
    print('handle points:', handle_points)
    print('target points:', target_points)

    init_code = invert_code
    init_code_orig = deepcopy(init_code)
    model.scheduler.set_timesteps(args.n_inference_step)
    t = model.scheduler.timesteps[args.n_inference_step - args.n_actual_inference_step]

    # feature shape: [1280,16,16], [1280,32,32], [640,64,64], [320,64,64]
    # update according to the given supervision
    updated_init_code, gif_updated_init_code = drag_diffusion_update(
        model, init_code, t, handle_points, target_points, mask, args)

    # inject the reference branch to guide the generation
    _register_editor(model, lora_path, start_step, start_layer, args)

    # inference the synthesized image; batch is [reference, edited], keep the edited one
    gen_image = model(
        prompt=args.prompt,
        batch_size=2,
        latents=torch.cat([init_code_orig, updated_init_code], dim=0),
        guidance_scale=args.guidance_scale,
        num_inference_steps=args.n_inference_step,
        num_actual_inference_steps=args.n_actual_inference_step
    )[1].unsqueeze(dim=0)

    # if gif, inference the synthesized image for each step and save them to gif.
    # Frames use a separate variable so `gen_image` (the final result) is not overwritten,
    # and the same [reference, edited] batch as the main call so the output is 4-D.
    if args.create_gif_checkbox:
        out_frames = []
        for step_updated_init_code in gif_updated_init_code:
            frame_image = model(
                prompt=args.prompt,
                batch_size=2,
                latents=torch.cat([init_code_orig, step_updated_init_code], dim=0),
                guidance_scale=args.guidance_scale,
                num_inference_steps=args.n_inference_step,
                num_actual_inference_steps=args.n_actual_inference_step
            )[1].unsqueeze(dim=0)
            out_frames.append(_to_uint8_image(frame_image))
        imageio.mimsave(_timestamped_path(save_dir, '.gif'), out_frames, fps=10)

    # save the original image, user editing instructions, synthesized image;
    # separators use the real image height rather than a hard-coded 512
    save_result = torch.cat([
        source_image * 0.5 + 0.5,
        torch.ones((1, 3, full_h, 25), device=device),
        image_with_clicks * 0.5 + 0.5,
        torch.ones((1, 3, full_h, 25), device=device),
        gen_image[0:1]
    ], dim=-1)
    save_image(save_result, _timestamped_path(save_dir, '.png'))

    return _to_uint8_image(gen_image)
# -------------------------------------------------------


# ----------- dragging generated image utils -----------
# once the user generated an image
# it will be displayed on mask drawing-areas and point-clicking area
def gen_img(
    length,  # length of the window displaying the image
    height,  # height of the generated image
    width,  # width of the generated image
    n_inference_step,
    scheduler_name,
    seed,
    guidance_scale,
    prompt,
    neg_prompt,
    model_path,
    vae_path,
    lora_path):
    """Generate an image from a prompt in fp16 and return Gradio updates plus intermediate latents.

    Returns:
        (generated-image update, two empty image updates, None, intermediate_latents)
    """
    # initialize model
    device = _get_device()
    model = DragPipeline.from_pretrained(model_path, torch_dtype=torch.float16).to(device)
    model.scheduler = _build_scheduler(scheduler_name, model)

    # call this function to override unet forward function,
    # so that intermediate features are returned after forward
    model.modify_unet_forward()
    _maybe_load_vae(model, vae_path)
    _maybe_load_lora_weights(model, lora_path)

    # initialize init noise
    seed_everything(seed)
    init_noise = torch.randn([1, 4, height // 8, width // 8], device=device, dtype=model.vae.dtype)
    gen_image, intermediate_latents = model(prompt=prompt,
                                            neg_prompt=neg_prompt,
                                            num_inference_steps=n_inference_step,
                                            latents=init_noise,
                                            guidance_scale=guidance_scale,
                                            return_intermediates=True)
    gen_image = _to_uint8_image(gen_image)

    if height < width:
        # need to do this due to Gradio's bug
        display_h = int(length * height / width)
        return gr.Image.update(value=gen_image, height=display_h, width=length), \
            gr.Image.update(height=display_h, width=length), \
            gr.Image.update(height=display_h, width=length), \
            None, \
            intermediate_latents
    return gr.Image.update(value=gen_image, height=length, width=length), \
        gr.Image.update(value=None, height=length, width=length), \
        gr.Image.update(value=None, height=length, width=length), \
        None, \
        intermediate_latents


def run_drag_gen(
    n_inference_step,
    scheduler_name,
    source_image,
    image_with_clicks,
    intermediate_latents_gen,
    guidance_scale,
    mask,
    prompt,
    neg_prompt,
    points,
    inversion_strength,
    lam,
    latent_lr,
    n_pix_step,
    model_path,
    vae_path,
    lora_path,
    start_step,
    start_layer,
    create_gif_checkbox,
    create_tracking_points_checkbox,
    gif_interval,
    gif_fps,
    save_dir="./results"):
    """Drag-edit an image produced by `gen_img`, starting from its stored intermediate latents.

    Dragging runs in fp32; generation runs in fp16. Optionally saves a GIF of
    intermediate edits and a GIF of handle-point tracking to `save_dir`.

    Returns:
        The edited image as an (H, W, 3) uint8 array.
    """
    # initialize model
    device = _get_device()
    model = DragPipeline.from_pretrained(model_path, torch_dtype=torch.float16).to(device)
    model.scheduler = _build_scheduler(scheduler_name, model)

    # call this function to override unet forward function,
    # so that intermediate features are returned after forward
    model.modify_unet_forward()
    _maybe_load_vae(model, vae_path)

    # initialize parameters
    seed = 42  # random seed used by a lot of people for unknown reason
    seed_everything(seed)

    args = SimpleNamespace()
    args.prompt = prompt
    args.neg_prompt = neg_prompt
    args.points = points
    args.n_inference_step = n_inference_step
    args.n_actual_inference_step = round(n_inference_step * inversion_strength)
    args.guidance_scale = guidance_scale

    args.unet_feature_idx = [3]

    full_h, full_w = source_image.shape[:2]

    args.sup_res_h = int(0.5 * full_h)
    args.sup_res_w = int(0.5 * full_w)

    args.r_m = 1
    args.r_p = 3
    args.lam = lam

    args.lr = latent_lr

    args.n_pix_step = n_pix_step
    args.create_gif_checkbox = create_gif_checkbox
    args.create_tracking_points_checkbox = create_tracking_points_checkbox
    args.gif_interval = gif_interval
    print(args)

    source_image = preprocess_image(source_image, device)
    image_with_clicks = preprocess_image(image_with_clicks, device)

    _maybe_load_lora_weights(model, lora_path)

    # `.to(device)` instead of `.cuda()` so the CPU fallback chosen above also works
    mask = _prepare_mask(mask, (args.sup_res_h, args.sup_res_w), device)

    # points arrive as (x, y); convert to (row, col) on the half-resolution grid.
    # Even indices are handle points, odd indices are target points.
    handle_points = []
    target_points = []
    for idx, point in enumerate(points):
        cur_point = torch.tensor([point[1] / full_h * args.sup_res_h,
                                  point[0] / full_w * args.sup_res_w])
        cur_point = torch.round(cur_point)
        (handle_points if idx % 2 == 0 else target_points).append(cur_point)
    print('handle points:', handle_points)
    print('target points:', target_points)

    model.scheduler.set_timesteps(args.n_inference_step)
    t = model.scheduler.timesteps[args.n_inference_step - args.n_actual_inference_step]
    init_code = deepcopy(intermediate_latents_gen[args.n_inference_step - args.n_actual_inference_step])
    init_code_orig = deepcopy(init_code)

    # feature shape: [1280,16,16], [1280,32,32], [640,64,64], [320,64,64]
    # update according to the given supervision; the optimization runs in fp32
    init_code = init_code.to(torch.float32)
    model = model.to(device, torch.float32)
    updated_init_code, gif_updated_init_code, handle_points_list = drag_diffusion_update_gen(
        model, init_code, t, handle_points, target_points, mask, args)
    updated_init_code = updated_init_code.to(torch.float16)
    model = model.to(device, torch.float16)

    # inject the reference branch to guide the generation
    _register_editor(model, lora_path, start_step, start_layer, args)

    # inference the synthesized image
    gen_image = model(
        prompt=args.prompt,
        neg_prompt=args.neg_prompt,
        batch_size=2,  # batch size is 2 because we have reference init_code and updated init_code
        latents=torch.cat([init_code_orig, updated_init_code], dim=0),
        guidance_scale=args.guidance_scale,
        num_inference_steps=args.n_inference_step,
        num_actual_inference_steps=args.n_actual_inference_step
    )[1].unsqueeze(dim=0)

    # if gif, inference the synthesized image for each step and save them to gif.
    # Frames use a separate variable so `gen_image` (the final result) is not overwritten,
    # and pass the same negative prompt as the final result for consistent frames.
    if args.create_gif_checkbox:
        out_frames = []
        print('Start Generate GIF')
        for step_updated_init_code in gif_updated_init_code:
            step_updated_init_code = step_updated_init_code.to(torch.float16)
            frame_image = model(
                prompt=args.prompt,
                neg_prompt=args.neg_prompt,
                batch_size=2,
                latents=torch.cat([init_code_orig, step_updated_init_code], dim=0),
                guidance_scale=args.guidance_scale,
                num_inference_steps=args.n_inference_step,
                num_actual_inference_steps=args.n_actual_inference_step
            )[1].unsqueeze(dim=0)
            out_frames.append(_to_uint8_image(frame_image))
        imageio.mimsave(_timestamped_path(save_dir, '.gif'), out_frames, fps=gif_fps)

    if args.create_tracking_points_checkbox:
        white_image_base = np.ones((full_h, full_w, 3), dtype=np.uint8) * 255
        out_points_frames = []
        # last drawn location of every handle point, used to draw its trajectory segment
        previous_points = {i: None for i in range(len(handle_points))}
        print('Start Generate Tracking Points GIF', len(handle_points_list), handle_points_list)
        # NOTE(review): handle points were expressed on the half-resolution
        # (sup_res_h, sup_res_w) grid above, but are drawn here on a full_h x full_w
        # canvas without rescaling — confirm which grid drag_diffusion_update_gen returns.
        for step_handle_points in handle_points_list:
            out_points_frame = white_image_base.copy()
            for idx, point in enumerate(step_handle_points):
                # points are (row, col); OpenCV expects (x, y)
                current_point = (int(point[1].item()), int(point[0].item()))
                cv2.circle(out_points_frame, current_point, 4, (0, 0, 255), -1)
                cv2.putText(out_points_frame, f'P{idx}', (current_point[0] + 5, current_point[1]),
                            cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 0, 255), 1)
                # draw the segment from the previous step's location
                if previous_points[idx] is not None:
                    cv2.line(out_points_frame, previous_points[idx], current_point, (0, 255, 0), 2)
                previous_points[idx] = current_point
            out_points_frames.append(out_points_frame)
        imageio.mimsave(_timestamped_path(save_dir, '_tracking_points.gif'),
                        out_points_frames, fps=gif_fps)

    # save the original image, user editing instructions, synthesized image
    save_result = torch.cat([
        source_image * 0.5 + 0.5,
        torch.ones((1, 3, full_h, 25), device=device),
        image_with_clicks * 0.5 + 0.5,
        torch.ones((1, 3, full_h, 25), device=device),
        gen_image[0:1]
    ], dim=-1)
    save_image(save_result, _timestamped_path(save_dir, '.png'))

    return _to_uint8_image(gen_image)

# ------------------------------------------------------