import gradio as gr
import numpy as np
import cv2
from PIL import Image, ImageFilter
import torch
import datetime
import os
import uuid
from pipeline.pipeline_svd_DragAnything import StableVideoDiffusionPipeline
from models.DragAnything import DragAnythingSDVModel
from models.unet_spatio_temporal_condition_controlnet import UNetSpatioTemporalConditionControlNetModel
import re
from scipy.ndimage import distance_transform_edt
import torchvision.transforms as T
import torch.nn.functional as F
from utils.dift_util import DIFT_Demo, SDFeaturizer
from torchvision.transforms import PILToTensor
import json
from utils_drag import *
from scipy.interpolate import interp1d, PchipInterpolator
from segment_anything import sam_model_registry, SamPredictor
import imageio
from moviepy.editor import *

print("gr file", gr.__file__)

# Random colours used when visualising segmentation masks.
color_list = []
for i in range(20):
    color = np.concatenate([np.random.random(4) * 255], axis=0)
    color_list.append(color)

output_dir = "./outputs"
ensure_dirname(output_dir)

# SAM (Segment Anything) predictor used to extract entity masks from the first frame.
sam_checkpoint = "./script/sam_vit_h_4b8939.pth"
model_type = "vit_h"
device = "cuda"
sam = sam_model_registry[model_type](checkpoint=sam_checkpoint)
sam.to(device=device)
predictor = SamPredictor(sam)


def get_dift_ID(feature_map, mask):
    """Average the DIFT feature vectors over all non-zero pixels of the mask."""
    # feature_map = feature_map * 0
    new_feature = []
    non_zero_coordinates = np.column_stack(np.where(mask != 0))
    for coord in non_zero_coordinates:
        # feature_map[:, coord[0], coord[1]] = 1
        new_feature.append(feature_map[:, coord[0], coord[1]])
    stacked_tensor = torch.stack(new_feature, dim=0)
    # Average-pool over dimension 0 (all pixels inside the mask).
    average_pooled_tensor = torch.mean(stacked_tensor, dim=0)

    return average_pooled_tensor


def interpolate_trajectory(points, n_points):
    """Resample a hand-drawn trajectory to exactly n_points via monotone (PCHIP) interpolation."""
    x = [point[0] for point in points]
    y = [point[1] for point in points]

    t = np.linspace(0, 1, len(points))

    # fx = interp1d(t, x, kind='cubic')
    # fy = interp1d(t, y, kind='cubic')
    fx = PchipInterpolator(t, x)
    fy = PchipInterpolator(t, y)

    new_t = np.linspace(0, 1, n_points)

    new_x = fx(new_t)
    new_y = fy(new_t)
    new_points = list(zip(new_x, new_y))

    return new_points


def visualize_drag_v2(background_image_path, splited_tracks, width, height):
    trajectory_maps = []

    background_image = Image.open(background_image_path).convert('RGBA')
    background_image = background_image.resize((width, height))
    w, h = background_image.size
    transparent_background = np.array(background_image)
    transparent_background[:, :, -1] = 128
    transparent_background = Image.fromarray(transparent_background)

    # Create a transparent layer with the same size as the background image
    transparent_layer = np.zeros((h, w, 4))
    for splited_track in splited_tracks:
        if len(splited_track) > 1:
            splited_track = interpolate_trajectory(splited_track, 16)
            splited_track = splited_track[:16]
            for i in range(len(splited_track) - 1):
                start_point = (int(splited_track[i][0]), int(splited_track[i][1]))
                end_point = (int(splited_track[i + 1][0]), int(splited_track[i + 1][1]))
                vx = end_point[0] - start_point[0]
                vy = end_point[1] - start_point[1]
                arrow_length = np.sqrt(vx ** 2 + vy ** 2)
                if i == len(splited_track) - 2:
                    cv2.arrowedLine(transparent_layer, start_point, end_point,
                                    (255, 0, 0, 192), 2, tipLength=8 / arrow_length)
                else:
                    cv2.line(transparent_layer, start_point, end_point, (255, 0, 0, 192), 2)
        else:
            cv2.circle(transparent_layer,
                       (int(splited_track[0][0]), int(splited_track[0][1])),
                       5, (255, 0, 0, 192), -1)

    transparent_layer = Image.fromarray(transparent_layer.astype(np.uint8))
    trajectory_map = Image.alpha_composite(transparent_background, transparent_layer)
    trajectory_maps.append(trajectory_map)

    return trajectory_maps, transparent_layer
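
# Minimal sketch of how the per-entity masks ("<output_dir>/mask_{i}.jpg") that
# get_condition() reads below could be produced with the SAM predictor set up
# above. The helper name and the click coordinates are illustrative
# placeholders; the Gradio callbacks are expected to do the equivalent.
def _sketch_segment_entity(first_frame_path, click_x, click_y, index=1):
    frame = np.array(Image.open(first_frame_path).convert('RGB'))
    predictor.set_image(frame)
    masks, _, _ = predictor.predict(
        point_coords=np.array([[click_x, click_y]]),  # a single foreground click
        point_labels=np.array([1]),
        multimask_output=False,
    )
    mask_path = os.path.join(output_dir, "mask_{}.jpg".format(index))
    cv2.imwrite(mask_path, (masks[0] * 255).astype(np.uint8))
    return mask_path
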

def gen_gaussian_heatmap(imgSize=200):
    """Create a circular Gaussian heatmap used to mark an entity's centre."""
    circle_img = np.zeros((imgSize, imgSize), np.float32)
    circle_mask = cv2.circle(circle_img, (imgSize // 2, imgSize // 2), imgSize // 2, 1, -1)

    isotropicGrayscaleImage = np.zeros((imgSize, imgSize), np.float32)
    for i in range(imgSize):
        for j in range(imgSize):
            isotropicGrayscaleImage[i, j] = 1 / 2 / np.pi / (40 ** 2) * np.exp(
                -1 / 2 * ((i - imgSize / 2) ** 2 / (40 ** 2) + (j - imgSize / 2) ** 2 / (40 ** 2)))

    isotropicGrayscaleImage = isotropicGrayscaleImage * circle_mask
    isotropicGrayscaleImage = (isotropicGrayscaleImage / np.max(isotropicGrayscaleImage)).astype(np.float32)
    isotropicGrayscaleImage = (isotropicGrayscaleImage / np.max(isotropicGrayscaleImage) * 255).astype(np.uint8)

    return isotropicGrayscaleImage


def extract_dift_feature(image, dift_model):
    """Extract a DIFT (diffusion) feature map from a PIL image or an image path."""
    if isinstance(image, Image.Image):
        image = image
    else:
        image = Image.open(image).convert('RGB')

    prompt = ''
    img_tensor = (PILToTensor()(image) / 255.0 - 0.5) * 2
    dift_feature = dift_model.forward(img_tensor, prompt=prompt, up_ft_index=3, ensemble_size=8)

    return dift_feature


def get_condition(target_size=(512, 512), points=None, original_size=(512, 512), args="",
                  first_frame=None, is_mask=False, side=20, model_id=None):
    """Build the per-frame control heatmaps, ID embeddings and trajectory
    visualisations from the user-drawn points and the saved entity masks."""
    images = []
    vis_images = []
    heatmap = gen_gaussian_heatmap()

    original_size = (original_size[1], original_size[0])
    size = (target_size[1], target_size[0])
    latent_size = (int(target_size[1] / 8), int(target_size[0] / 8))

    dift_model = SDFeaturizer(sd_id=model_id)
    keyframe_dift = extract_dift_feature(first_frame, dift_model=dift_model)

    ID_images = []
    ids_list = {}

    mask_list = []
    trajectory_list = []
    radius_list = []

    for index, point in enumerate(points):
        mask_name = output_dir + "/" + "mask_{}.jpg".format(index + 1)
        trajectories = [[int(i[0]), int(i[1])] for i in point]
        trajectory_list.append(trajectories)

        first_mask = (cv2.imread(mask_name) / 255).astype(np.uint8)
        first_mask = cv2.cvtColor(first_mask.astype(np.uint8), cv2.COLOR_RGB2GRAY)
        mask_list.append(first_mask)

        mask_322 = cv2.resize(first_mask.astype(np.uint8), (int(target_size[1]), int(target_size[0])))
        _, radius = find_largest_inner_rectangle_coordinates(mask_322)
        radius_list.append(radius)

    for idxx, point in enumerate(trajectory_list[0]):
        new_img = np.zeros(target_size, np.uint8)
        vis_img = new_img.copy()
        ids_embedding = torch.zeros((target_size[0], target_size[1], 320))

        if idxx >= args["frame_number"]:
            break

        for cc, (mask, trajectory, radius) in enumerate(zip(mask_list, trajectory_list, radius_list)):
            center_coordinate = trajectory[idxx]
            trajectory_ = trajectory[:idxx]
            side = min(radius, 50)

            # ID embedding
            if idxx == 0:
                # diffusion feature
                mask_32 = cv2.resize(mask.astype(np.uint8), latent_size)
                if len(np.column_stack(np.where(mask_32 != 0))) == 0:
                    continue
                ids_list[cc] = get_dift_ID(keyframe_dift[0], mask_32)
                id_feature = ids_list[cc]
            else:
                id_feature = ids_list[cc]

            circle_img = np.zeros((target_size[0], target_size[1]), np.float32)
            circle_mask = cv2.circle(circle_img, (center_coordinate[0], center_coordinate[1]), side, 1, -1)

            y1 = max(center_coordinate[1] - side, 0)
            y2 = min(center_coordinate[1] + side, target_size[0] - 1)
            x1 = max(center_coordinate[0] - side, 0)
            x2 = min(center_coordinate[0] + side, target_size[1] - 1)

            if x2 - x1 > 3 and y2 - y1 > 3:
                need_map = cv2.resize(heatmap, (x2 - x1, y2 - y1))
                new_img[y1:y2, x1:x2] = need_map.copy()

                if cc >= 0:
                    vis_img[y1:y2, x1:x2] = need_map.copy()
                    if len(trajectory_) == 1:
                        vis_img[trajectory_[0][1], trajectory_[0][0]] = 255
                    else:
                        for itt in range(len(trajectory_) - 1):
                            cv2.line(vis_img,
                                     (trajectory_[itt][0], trajectory_[itt][1]),
                                     (trajectory_[itt + 1][0], trajectory_[itt + 1][1]),
                                     (255, 255, 255), 3)

            non_zero_coordinates = np.column_stack(np.where(circle_mask != 0))
            for coord in non_zero_coordinates:
                ids_embedding[coord[0], coord[1]] = id_feature[0]

        ids_embedding = F.avg_pool1d(ids_embedding, kernel_size=2, stride=2)

        img = new_img
        # Ensure all images are in RGB format
        if len(img.shape) == 2:  # Grayscale image
            img = cv2.cvtColor(img, cv2.COLOR_GRAY2RGB)
            vis_img = cv2.cvtColor(vis_img, cv2.COLOR_GRAY2RGB)
        elif len(img.shape) == 3 and img.shape[2] == 3:  # Color image in BGR format
            img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
            vis_img = cv2.cvtColor(vis_img, cv2.COLOR_BGR2RGB)

        # Convert the numpy array to a PIL image
        pil_img = Image.fromarray(img)
        images.append(pil_img)
        vis_images.append(Image.fromarray(vis_img))
        ID_images.append(ids_embedding)

    return images, ID_images, vis_images


def find_largest_inner_rectangle_coordinates(mask_gray):
    refine_dist = cv2.distanceTransform(mask_gray.astype(np.uint8), cv2.DIST_L2, 5, cv2.DIST_LABEL_PIXEL)
    _, maxVal, _, maxLoc = cv2.minMaxLoc(refine_dist)
    radius = int(maxVal)

    return maxLoc, radius


def save_gifs_side_by_side(batch_output, validation_control_images, output_folder,
                           name='none', target_size=(512, 512), duration=200):
    flattened_batch_output = batch_output

    def create_gif(image_list, gif_path, duration=100):
        pil_images = [validate_and_convert_image(img, target_size=target_size) for img in image_list]
        pil_images = [img for img in pil_images if img is not None]
        if pil_images:
            pil_images[0].save(gif_path, save_all=True, append_images=pil_images[1:], loop=0, duration=duration)

    # Creating GIFs for each image list
    timestamp = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
    gif_paths = []
    # validation_control_images = validation_control_images*255
    for idx, image_list in enumerate([validation_control_images, flattened_batch_output]):
        # if idx == 0:
        #     continue
        gif_path = os.path.join(output_folder.replace("vis_gif.gif", ""), f"temp_{idx}_{timestamp}.gif")
        create_gif(image_list, gif_path)
        gif_paths.append(gif_path)

    # Function to combine GIFs side by side
    def combine_gifs_side_by_side(gif_paths, output_path):
        print(gif_paths)
        gifs = [Image.open(gif) for gif in gif_paths]

        # Assuming all gifs have the same frame count and duration
        frames = []
        for frame_idx in range(gifs[0].n_frames):
            combined_frame = None
            for gif in gifs:
                gif.seek(frame_idx)
                if combined_frame is None:
                    combined_frame = gif.copy()
                else:
                    combined_frame = get_concat_h(combined_frame, gif.copy())
            frames.append(combined_frame)

        print(gifs[0].info['duration'])
        frames[0].save(output_path, save_all=True, append_images=frames[1:], loop=0, duration=duration)

    # Helper function to concatenate images horizontally
    def get_concat_h(im1, im2):
        dst = Image.new('RGB', (im1.width + im2.width, max(im1.height, im2.height)))
        dst.paste(im1, (0, 0))
        dst.paste(im2, (im1.width, 0))
        return dst

    # Combine the GIFs into a single file
    combined_gif_path = output_folder
    combine_gifs_side_by_side(gif_paths, combined_gif_path)

    # Clean up temporary GIFs
    for gif_path in gif_paths:
        os.remove(gif_path)

    return combined_gif_path


# Define functions
def validate_and_convert_image(image, target_size=(512, 512)):
    if image is None:
        print("Encountered a None image")
        return None

    if isinstance(image, torch.Tensor):
        # Convert PyTorch tensor to PIL Image
        if image.ndim == 3 and image.shape[0] in [1, 3]:  # Check for CxHxW format
            if image.shape[0] == 1:  # Convert single-channel grayscale to RGB
                image = image.repeat(3, 1, 1)
            image = image.mul(255).clamp(0, 255).byte().permute(1, 2, 0).cpu().numpy()
            image = Image.fromarray(image)
        else:
            print(f"Invalid image tensor shape: {image.shape}")
            return None
    elif isinstance(image, Image.Image):
        # Resize PIL Image
        image = image.resize(target_size)
    else:
        print("Image is not a PIL Image or a PyTorch tensor")
        return None

    return image


class Drag:
    """Wrapper that loads the DragAnything/SVD pipeline and turns the user-drawn
    trajectories into a generated video saved as a GIF."""

    def __init__(self, device, args, height, width, model_length):
        self.device = device

        self.controlnet = DragAnythingSDVModel.from_pretrained(args["DragAnything"])
        unet = UNetSpatioTemporalConditionControlNetModel.from_pretrained(
            args["pretrained_model_name_or_path"], subfolder="unet")
        self.pipeline = StableVideoDiffusionPipeline.from_pretrained(
            args["pretrained_model_name_or_path"], controlnet=self.controlnet, unet=unet)
        self.pipeline.enable_model_cpu_offload()

        self.height = height
        self.width = width
        self.args = args
        self.model_length = model_length

    def run(self, first_frame_path, tracking_points, inference_batch_size, motion_bucket_id):
        original_width, original_height = 576, 320

        input_all_points = tracking_points.constructor_args['value']
        resized_all_points = [
            tuple([tuple([int(e1[0] * self.width / original_width),
                          int(e1[1] * self.height / original_height)]) for e1 in e])
            for e in input_all_points
        ]

        input_drag = torch.zeros(self.model_length - 1, self.height, self.width, 2)
        for idx, splited_track in enumerate(resized_all_points):
            if len(splited_track) == 1:  # stationary point
                displacement_point = tuple([splited_track[0][0] + 1, splited_track[0][1] + 1])
                splited_track = tuple([splited_track[0], displacement_point])
            # interpolate the track
            splited_track = interpolate_trajectory(splited_track, self.model_length)
            splited_track = splited_track[:self.model_length]
            resized_all_points[idx] = splited_track

        validation_image = Image.open(first_frame_path).convert('RGB')
        width, height = validation_image.size
        validation_image = validation_image.resize((self.width, self.height))

        validation_control_images, ids_embedding, vis_images = get_condition(
            target_size=(self.args["height"], self.args["width"]),
            points=resized_all_points,
            original_size=(self.height, self.width),
            args=self.args,
            first_frame=validation_image,
            side=100,
            model_id=args["model_DIFT"])

        ids_embedding = torch.stack(ids_embedding, dim=0).permute(0, 3, 1, 2)

        # Inference and saving loop
        video_frames = self.pipeline(
            validation_image,
            validation_control_images[:self.model_length],
            decode_chunk_size=8,
            num_frames=self.model_length,
            motion_bucket_id=motion_bucket_id,
            controlnet_cond_scale=1.0,
            height=self.height,
            width=self.width,
            ids_embedding=ids_embedding[:self.model_length]).frames

        vis_images = [cv2.applyColorMap(np.array(img).astype(np.uint8), cv2.COLORMAP_JET) for img in vis_images]
        vis_images = [cv2.cvtColor(np.array(img).astype(np.uint8), cv2.COLOR_BGR2RGB) for img in vis_images]
        vis_images = [Image.fromarray(img) for img in vis_images]

        video_frames = [img for sublist in video_frames for img in sublist]

        val_save_dir = output_dir + "/" + "vis_gif.gif"
        save_gifs_side_by_side(video_frames, vis_images[:self.model_length], val_save_dir,
                               target_size=(self.width, self.height), duration=110)

        # clip = Image.open(val_save_dir)
        # print(clip.size)

        return val_save_dir
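
# Minimal usage sketch (hypothetical wiring; the motion_bucket_id value is only
# an example): the Gradio callbacks below are expected to build one wrapper from
# the args dict defined next and call run() with the first-frame path and the
# user-drawn trajectories, roughly like this:
#
#   drag_model = Drag("cuda", args, args["height"], args["width"], args["frame_number"])
#   gif_path = drag_model.run(first_frame_path, tracking_points,
#                             inference_batch_size=1, motion_bucket_id=180)
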

# Hard-coded configuration: there is no command-line argument parsing, so edit
# these values directly in the file.
args = {
    "pretrained_model_name_or_path": "stabilityai/stable-video-diffusion-img2vid",
    "DragAnything": "./model_out/DragAnything",
    "model_DIFT": "./utils/pretrained_models/chilloutmix",
    "validation_image": "./validation_demo/Demo/ship_@",
    "output_dir": "./validation_demo",
    "height": 320,
    "width": 576,
    "frame_number": 14,
}

with gr.Blocks() as demo:
    gr.Markdown("""