import os |
import torch |
import datetime |
import numpy as np |
from PIL import Image |
from pipeline.pipeline_svd_DragAnything import StableVideoDiffusionPipeline |
from models.DragAnything import DragAnythingSDVModel |
from models.unet_spatio_temporal_condition_controlnet import UNetSpatioTemporalConditionControlNetModel |
import cv2 |
import re |
from scipy.ndimage import distance_transform_edt |
import torchvision.transforms as T |
import torch.nn.functional as F |
from utils.dift_util import DIFT_Demo, SDFeaturizer |
from torchvision.transforms import PILToTensor |
import json |
def save_gifs_side_by_side(batch_output, validation_control_images,output_folder,name = 'none', target_size=(512 , 512),duration=200):
    """Save control frames and generated frames as one side-by-side animated GIF.

    Frame ``k`` of the output shows ``validation_control_images[k]`` on the left
    and ``batch_output[k]`` on the right. Entries rejected by
    ``validate_and_convert_image`` are dropped; if the two sequences end up with
    different lengths, the output is truncated to the shorter one.

    Args:
        batch_output: Generated frames (PIL images or tensors accepted by
            ``validate_and_convert_image``).
        validation_control_images: Control / visualisation frames (same types).
        output_folder: Directory the GIF is written to.
        name: Tag embedded in the output file name.
        target_size: Size forwarded to ``validate_and_convert_image``
            (PIL ``resize`` order, i.e. (width, height)).
        duration: Display time of each frame in milliseconds.

    Returns:
        Path of the written GIF, or ``None`` when no frame pair could be built.
    """
    timestamp = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
    combined_gif_path = os.path.join(output_folder, f"combined_frames_{name}_{timestamp}.gif")

    def to_pil(image_list):
        # Convert every entry and drop the ones that could not be converted.
        converted = [validate_and_convert_image(img, target_size=target_size) for img in image_list]
        return [img for img in converted if img is not None]

    def get_concat_h(im1, im2):
        # Paste im2 to the right of im1 on an RGB canvas as tall as the taller image.
        dst = Image.new('RGB', (im1.width + im2.width, max(im1.height, im2.height)))
        dst.paste(im1, (0, 0))
        dst.paste(im2, (im1.width, 0))
        return dst

    # Frames are combined in memory instead of via temporary per-column GIFs:
    # Pillow merges consecutive identical frames when writing a GIF, so the two
    # temporary files could end up with different frame counts and the columns
    # would fall out of sync (they were also colour-quantised twice).
    left_frames = to_pil(validation_control_images)
    right_frames = to_pil(batch_output)
    frames = [get_concat_h(left, right) for left, right in zip(left_frames, right_frames)]
    if not frames:
        print("No frames to save")
        return None
    frames[0].save(combined_gif_path, save_all=True, append_images=frames[1:], loop=0, duration=duration)
    return combined_gif_path
def validate_and_convert_image(image, target_size=(512 , 512)):
    """Convert ``image`` to a PIL image resized to ``target_size``.

    Accepted inputs:
      * a PIL image;
      * a 3-D tensor shaped (C, H, W) with C in {1, 3}. Single-channel tensors
        are repeated to 3 channels. ``uint8`` tensors are used as-is. Other
        dtypes are multiplied by 255 and clamped, so they are presumably in
        [0, 1] — confirm with callers.

    ``target_size`` is passed to ``PIL.Image.resize``, i.e. (width, height). It
    is now applied to tensor inputs too, so grids and GIFs built from mixed
    inputs get uniformly sized cells.

    Returns:
        The converted PIL image, or ``None`` (after printing a message) when
        ``image`` is ``None``, a tensor of unsupported shape, or of another type.
    """
    if image is None:
        print("Encountered a None image")
        return None
    if isinstance(image, torch.Tensor):
        if image.ndim != 3 or image.shape[0] not in (1, 3):
            print(f"Invalid image tensor shape: {image.shape}")
            return None
        if image.shape[0] == 1:
            image = image.repeat(3, 1, 1)
        image = image.detach()
        if image.dtype != torch.uint8:
            # Scaling a uint8 tensor by 255 would wrap around; only scale other dtypes.
            image = image.mul(255).clamp(0, 255)
        image = Image.fromarray(image.byte().permute(1, 2, 0).cpu().numpy())
    elif not isinstance(image, Image.Image):
        print("Image is not a PIL Image or a PyTorch tensor")
        return None
    return image.resize(target_size)
def create_image_grid(images, rows, cols, target_size=(512 , 512)):
    """Lay images out row-major on a ``rows`` x ``cols`` RGB canvas.

    Each entry is first passed through ``validate_and_convert_image``. Rejected
    entries are dropped, so later images move up to fill their slots. Each cell
    is ``target_size`` = (width, height) pixels.

    Returns:
        The grid image, or ``None`` (after printing a message) if nothing was valid.
    """
    converted = (validate_and_convert_image(img, target_size) for img in images)
    usable = [picture for picture in converted if picture is not None]
    if len(usable) == 0:
        print("No valid images to create a grid")
        return None

    cell_w, cell_h = target_size
    canvas = Image.new('RGB', size=(cols * cell_w, rows * cell_h))
    for position, picture in enumerate(usable):
        row, col = divmod(position, cols)
        canvas.paste(picture, box=(col * cell_w, row * cell_h))
    return canvas
def tensor_to_pil(tensor):
    """Convert a (C, H, W) or (N, C, H, W) tensor to PIL image(s).

    Values are handed to ``Image.fromarray`` unscaled, so the tensor is
    presumably ``uint8`` in [0, 255] — confirm with callers. The tensor is
    detached and moved to CPU first, so CUDA and autograd tensors work.
    Single-channel images become 2-D ('L' mode) arrays, because
    ``Image.fromarray`` rejects an (H, W, 1) array.

    Returns:
        A single PIL image for 3-D input, or a list of PIL images for 4-D input.
    """
    array = tensor.detach().cpu().numpy()

    def chw_to_pil(chw):
        hwc = chw.transpose(1, 2, 0)
        if hwc.shape[2] == 1:
            hwc = hwc[:, :, 0]
        return Image.fromarray(hwc)

    if array.ndim == 4:
        return [chw_to_pil(img) for img in array]
    return chw_to_pil(array)
def _flatten_to_pil(items):
    """Convert tensor entries with ``tensor_to_pil`` and flatten one level of lists."""
    flat = []
    for item in items:
        converted = tensor_to_pil(item) if torch.is_tensor(item) else item
        if isinstance(converted, list):
            flat.extend(converted)
        else:
            flat.append(converted)
    return flat


def save_combined_frames(batch_output, validation_images, validation_control_images, output_folder):
    """Save validation, control and generated frames as one 3-column PNG grid.

    Each argument may mix PIL images, tensors (converted with ``tensor_to_pil``)
    and lists of those; one level of lists is flattened. Frames appear in this
    order: validation images, control images, then generated frames. Each cell
    is 512x512.

    Returns:
        Path of the saved PNG, or ``None`` when no grid could be built.
    """
    # A stray pre-flattening of batch_output used to run here. Its result was
    # overwritten straight away, and it crashed when batch_output was a flat
    # list of images.
    combined_frames = (
        _flatten_to_pil(validation_images)
        + _flatten_to_pil(validation_control_images)
        + _flatten_to_pil(batch_output)
    )
    cols = 3
    rows = (len(combined_frames) + cols - 1) // cols  # ceiling division
    grid = create_image_grid(combined_frames, rows, cols, target_size=(512, 512))
    if grid is None:
        print("Failed to create image grid")
        return None
    timestamp = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
    output_path = os.path.join(output_folder, f"combined_frames_{timestamp}.png")
    grid.save(output_path)
    return output_path
def load_images_from_folder(folder):
    """Load the images in ``folder`` as RGB PIL images, ordered by frame number.

    A file is kept when its lower-cased extension is .jpg, .jpeg, .png, .bmp,
    .gif or .tiff. Files are sorted by the last run of digits in the name, or
    the second-to-last run when the last one is '0000'. Names without digits
    sort last. Ties are broken by file name, so the order does not depend on
    ``os.listdir``.
    """
    valid_extensions = {".jpg", ".jpeg", ".png", ".bmp", ".gif", ".tiff"}

    def frame_number(filename):
        matches = re.findall(r'\d+', filename)
        if not matches:
            return float('inf')
        if matches[-1] == '0000' and len(matches) > 1:
            return int(matches[-2])
        return int(matches[-1])

    images = []
    for filename in sorted(os.listdir(folder), key=lambda entry: (frame_number(entry), entry)):
        if os.path.splitext(filename)[1].lower() not in valid_extensions:
            continue
        # The context manager closes the file; convert() returns an independent image.
        with Image.open(os.path.join(folder, filename)) as opened:
            images.append(opened.convert('RGB'))
    return images
def gen_gaussian_heatmap(imgSize=200, sigma=40):
    """Return an ``imgSize`` x ``imgSize`` uint8 isotropic Gaussian heatmap.

    The Gaussian has standard deviation ``sigma`` pixels and is centred at
    ``imgSize / 2``. It is zeroed outside a filled circle of radius
    ``imgSize // 2`` drawn with ``cv2.circle``, then rescaled so its peak is 255.
    The default ``sigma=40`` matches the previously hard-coded value.
    """
    circle_img = np.zeros((imgSize, imgSize), np.float32)
    circle_mask = cv2.circle(circle_img, (imgSize // 2, imgSize // 2), imgSize // 2, 1, -1)

    # Vectorised over the whole grid (rows i, columns j) instead of a per-pixel Python loop.
    rows = np.arange(imgSize).reshape(-1, 1)
    cols = np.arange(imgSize).reshape(1, -1)
    center = imgSize / 2
    gaussian = 1 / 2 / np.pi / (sigma ** 2) * np.exp(
        -1 / 2 * ((rows - center) ** 2 / (sigma ** 2) + (cols - center) ** 2 / (sigma ** 2)))
    heatmap = gaussian.astype(np.float32) * circle_mask

    # A single normalisation is enough: map the peak to 255.
    heatmap = heatmap / np.max(heatmap) * 255
    return heatmap.astype(np.uint8)
def infer_model(model, image, device="cuda"):
    """Run ``model`` on a PIL image and return its output.

    The image is resized to 196x196, converted to a tensor, and normalised with
    mean (0.485, 0.456, 0.406) and std (0.229, 0.224, 0.225). It is then given a
    batch dimension, moved to ``device``, and passed as
    ``model(image, is_training=False)``.

    Args:
        model: Callable accepting ``(tensor, is_training=False)``.
        image: PIL image.
        device: Target device for the input tensor. It defaults to "cuda", the
            previously hard-coded value, and must match the model's device.
    """
    transform = T.Compose([
        T.Resize((196, 196)),
        T.ToTensor(),
        T.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
    ])
    batch = transform(image).unsqueeze(0).to(device)
    return model(batch, is_training=False)
def find_largest_inner_rectangle_coordinates(mask_gray):
    """Find the mask pixel farthest from the background, and that distance.

    Despite the name, this locates the (approximate, L2 with a 5x5 mask) largest
    circle inscribed in the non-zero region of ``mask_gray``.

    Returns:
        ``(maxLoc, radius)``: ``maxLoc`` is the ``(x, y)`` location from
        ``cv2.minMaxLoc``, and ``radius`` is the maximum distance truncated to int.
    """
    # The 4th positional parameter of cv2.distanceTransform is the output array
    # `dst`, not a label type (labels belong to distanceTransformWithLabels), so
    # the DIST_LABEL_PIXEL constant previously passed there had no effect.
    refine_dist = cv2.distanceTransform(mask_gray.astype(np.uint8), cv2.DIST_L2, 5)
    _, max_val, _, max_loc = cv2.minMaxLoc(refine_dist)
    return max_loc, int(max_val)
def get_ID(images_list,masks_list,dinov2):
    """Embed the masked object of a single image with ``dinov2``.

    The image is multiplied by the mask, cropped to the bounding box of the
    mask's largest external contour, and embedded via ``infer_model``.

    Args:
        images_list: Despite the name, a single image array. Presumably
            (H, W, 3) uint8 — confirm with callers.
        masks_list: Despite the name, a single 2-D mask. It is multiplied into
            the image, so it presumably holds 0/1 values — confirm.
        dinov2: Model forwarded to ``infer_model``.

    Raises:
        ValueError: if the mask contains no foreground contour.
    """
    image = images_list
    mask = masks_list.astype(np.uint8)
    contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
    if not contours:
        raise ValueError("get_ID: mask has no foreground region")
    x, y, w, h = cv2.boundingRect(max(contours, key=cv2.contourArea))
    masked = image * cv2.cvtColor(mask, cv2.COLOR_GRAY2RGB)
    crop = Image.fromarray(masked[y:y + h, x:x + w]).convert('RGB')
    return infer_model(dinov2, crop)
def get_dift_ID(feature_map,mask):
    """Average the feature vectors of ``feature_map`` over the non-zero pixels of ``mask``.

    Args:
        feature_map: Tensor indexed as ``feature_map[:, row, col]``, i.e. (C, H, W).
        mask: 2-D array-like; non-zero entries select the pixels to average.

    Returns:
        Tensor of shape (C,).

    Raises:
        ValueError: if ``mask`` has no non-zero entry.
    """
    rows, cols = np.nonzero(np.asarray(mask))
    if rows.size == 0:
        raise ValueError("get_dift_ID: mask has no non-zero pixels")
    row_idx = torch.as_tensor(rows, dtype=torch.long, device=feature_map.device)
    col_idx = torch.as_tensor(cols, dtype=torch.long, device=feature_map.device)
    # One advanced-indexing op gathers a (C, N) matrix instead of a Python loop over pixels.
    return feature_map[:, row_idx, col_idx].mean(dim=1)
def extract_dift_feature(image, dift_model):
    """Compute DIFT features of ``image`` with ``dift_model``.

    ``image`` is either a PIL image, used as-is with no mode conversion, or a
    path, which is opened and converted to RGB. Pixel values are mapped from
    [0, 255] to [-1, 1]. The result of
    ``dift_model.forward(img_tensor, prompt='', up_ft_index=3, ensemble_size=8)``
    is returned unchanged.
    """
    if not isinstance(image, Image.Image):
        # The context manager closes the file; convert() returns an independent image.
        with Image.open(image) as opened:
            image = opened.convert('RGB')
    img_tensor = (PILToTensor()(image) / 255.0 - 0.5) * 2
    return dift_model.forward(img_tensor, prompt='', up_ft_index=3, ensemble_size=8)
def get_condition(target_size=(512 , 512), original_size=(512 , 512), args="", first_frame=None, is_mask = False, side=20,model_id=None):
    """Build per-frame drag-condition images and ID embeddings from ``demo.json``.

    Reads ``<args["validation_image"]>/demo.json``. Each entry provides a
    ``mask_name`` (an image in the same folder) and a ``trajectory`` of
    ``[x, y]`` points in original-image pixels.

    For each frame index up to the length of the first trajectory and
    ``args["frame_number"]``, and for each object:
      * a Gaussian blob is drawn at the object's current trajectory point;
      * the object's DIFT feature, averaged over its first-frame mask, is written
        into a (H, W, 320) embedding inside a disk centred on that point.
    Objects whose mask vanishes at latent resolution are skipped in every frame.

    Args:
        target_size: (height, width) of the generated conditions.
        original_size: (height, width) of the image the trajectories refer to.
        args: Dict with keys "validation_image" (folder) and "frame_number".
        first_frame: PIL image or path given to ``extract_dift_feature``.
        is_mask: Unused.
        side: Unused; the per-object half-size is ``min(inscribed radius, 50)``.
        model_id: Forwarded as ``sd_id`` to ``SDFeaturizer``.

    Returns:
        ``(images, ID_images, vis_images)``:
          * condition images (RGB PIL);
          * per-frame embeddings of shape (H, W, 160), the last dimension
            average-pooled by 2;
          * visualisation images (PIL) that also draw the trajectory so far.
    """
    images = []
    vis_images = []
    ID_images = []
    ids_list = {}
    heatmap = gen_gaussian_heatmap()
    # Both size arguments arrive as (height, width); flip them to (width, height)
    # to match the [x, y] order of trajectory points and cv2's dsize.
    original_size = (original_size[1], original_size[0])
    size = (target_size[1], target_size[0])
    latent_size = (int(target_size[1] / 8), int(target_size[0] / 8))

    dift_model = SDFeaturizer(sd_id=model_id)
    keyframe_dift = extract_dift_feature(first_frame, dift_model=dift_model)

    with open(os.path.join(args["validation_image"], "demo.json"), 'r') as json_file:
        trajectory_json = json.load(json_file)

    mask_list = []
    trajectory_list = []
    radius_list = []
    for index in trajectory_json:
        ann = trajectory_json[index]
        # Rescale trajectory points from original-image pixels to target pixels.
        trajectories = [[int(p[0] / original_size[0] * size[0]), int(p[1] / original_size[1] * size[1])]
                        for p in ann["trajectory"]]
        trajectory_list.append(trajectories)
        # Pixels equal to 255 become 1; everything else becomes 0.
        first_mask = (cv2.imread(os.path.join(args["validation_image"], ann["mask_name"])) / 255).astype(np.uint8)
        first_mask = cv2.cvtColor(first_mask.astype(np.uint8), cv2.COLOR_RGB2GRAY)
        mask_list.append(first_mask)
        mask_322 = cv2.resize(first_mask.astype(np.uint8), (int(target_size[1]), int(target_size[0])))
        _, radius = find_largest_inner_rectangle_coordinates(mask_322)
        radius_list.append(radius)

    if not trajectory_list:
        return images, ID_images, vis_images

    for idxx in range(len(trajectory_list[0])):
        if idxx >= args["frame_number"]:
            break
        new_img = np.zeros(target_size, np.uint8)
        vis_img = new_img.copy()
        ids_embedding = torch.zeros((target_size[0], target_size[1], 320))
        for cc, (mask, trajectory, radius) in enumerate(zip(mask_list, trajectory_list, radius_list)):
            if idxx == 0:
                mask_32 = cv2.resize(mask.astype(np.uint8), latent_size)
                if not np.any(mask_32):
                    continue
                ids_list[cc] = get_dift_ID(keyframe_dift[0], mask_32)
            elif cc not in ids_list:
                # Object was skipped at frame 0 (empty latent mask). Previously
                # this raised KeyError on the next frame.
                continue
            id_feature = ids_list[cc]
            center_coordinate = trajectory[idxx]
            trajectory_ = trajectory[:idxx]
            side = min(radius, 50)

            circle_img = np.zeros((target_size[0], target_size[1]), np.float32)
            circle_mask = cv2.circle(circle_img, (center_coordinate[0], center_coordinate[1]), side, 1, -1)
            y1 = max(center_coordinate[1] - side, 0)
            y2 = min(center_coordinate[1] + side, target_size[0] - 1)
            x1 = max(center_coordinate[0] - side, 0)
            x2 = min(center_coordinate[0] + side, target_size[1] - 1)
            if x2 - x1 > 3 and y2 - y1 > 3:
                need_map = cv2.resize(heatmap, (x2 - x1, y2 - y1))
                new_img[y1:y2, x1:x2] = need_map.copy()
                vis_img[y1:y2, x1:x2] = need_map.copy()
                # Draw the path travelled before this frame on the visualisation only.
                if len(trajectory_) == 1:
                    vis_img[trajectory_[0][1], trajectory_[0][0]] = 255
                else:
                    for itt in range(len(trajectory_) - 1):
                        cv2.line(vis_img, (trajectory_[itt][0], trajectory_[itt][1]),
                                 (trajectory_[itt + 1][0], trajectory_[itt + 1][1]), (255, 255, 255), 3)

            # NOTE(review): id_feature comes from get_dift_ID. If it is a 1-D (C,)
            # vector, id_feature[0] is a scalar broadcast over all 320 channels —
            # confirm the DIFT output shape.
            non_zero_coordinates = np.column_stack(np.where(circle_mask != 0))
            for coord in non_zero_coordinates:
                ids_embedding[coord[0], coord[1]] = id_feature[0]

        # avg_pool1d pools the last dimension: (H, W, 320) -> (H, W, 160).
        ids_embedding = F.avg_pool1d(ids_embedding, kernel_size=2, stride=2)
        img = new_img
        if len(img.shape) == 2:
            img = cv2.cvtColor(img, cv2.COLOR_GRAY2RGB)
            vis_img = cv2.cvtColor(vis_img, cv2.COLOR_GRAY2RGB)
        elif len(img.shape) == 3 and img.shape[2] == 3:
            img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
            vis_img = cv2.cvtColor(vis_img, cv2.COLOR_BGR2RGB)
        images.append(Image.fromarray(img))
        vis_images.append(Image.fromarray(vis_img))
        ID_images.append(ids_embedding)
    return images, ID_images, vis_images
def convert_list_bgra_to_rgba(image_list):
    """
    Swap the first and third bands of every PIL image in a list.
    Images in mode 'RGBA' or 'BGRA' are read as 4 bands (B, G, R, A) and rebuilt
    as 'RGBA'. Any other image is unpacked as exactly 3 bands (B, G, R) and
    rebuilt as 'RGB'; a different band count raises ValueError.
    Parameters:
    image_list (list of PIL.Image.Image): Images whose channel order is BGR(A).
    Returns:
    list of PIL.Image.Image: The channel-swapped images, in input order.
    """
    def swap_channels(img):
        if img.mode in ('RGBA', 'BGRA'):
            blue, green, red, alpha = img.split()
            return Image.merge("RGBA", (red, green, blue, alpha))
        blue, green, red = img.split()
        return Image.merge("RGB", (red, green, blue))

    return [swap_channels(img) for img in image_list]
def show_mask(image, masks, random_color=False):
    """Blend randomly coloured masks onto ``image``.

    Args:
        image: Array-like (anything ``np.array`` accepts) of shape (H, W, 3).
        masks: List of 2-D masks, presumably 0/1 valued.
            * ``random_color=False``: only the first two masks are drawn, each
              with its own random colour.
            * ``random_color=True``: only the first mask is drawn. A bare 2-D
              mask array is also accepted in this mode.
        random_color: Selects the mode described above.

    Returns:
        ``uint8`` array equal to ``0.4 * image + 0.6 * coloured masks``.
    """
    if random_color:
        # This branch used to reference an undefined name `mask` (NameError).
        mask = masks[0] if isinstance(masks, (list, tuple)) else masks
        h, w = mask.shape[:2]
        color_a = np.random.random(3) * 255
        mask_image = mask.reshape(h, w, 1) * color_a.reshape(1, 1, -1)
    else:
        h, w = masks[0].shape[:2]
        mask_image = 0
        for mask in masks[:2]:  # only the first two masks are drawn
            color = np.random.random(3) * 255
            mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1) + mask_image
    return (np.array(image).copy() * 0.4 + mask_image * 0.6).astype(np.uint8)
if __name__ == "__main__":
    # Demo configuration. The "validation_image" folder must contain demo.jpg,
    # demo.json and the mask images that demo.json names.
    args = {
        "pretrained_model_name_or_path": "stabilityai/stable-video-diffusion-img2vid",
        "DragAnything": "./model_out/DragAnything",
        "model_DIFT": "./utils/pretrained_models/chilloutmix",
        "validation_image": "./validation_demo/Demo/ship_@",
        "output_dir": "./validation_demo",
        "height": 320,
        "width": 576,
        "frame_number": 20
    }

    # The DragAnything network is handed to the pipeline as its `controlnet`.
    controlnet = DragAnythingSDVModel.from_pretrained(args["DragAnything"])
    unet = UNetSpatioTemporalConditionControlNetModel.from_pretrained(args["pretrained_model_name_or_path"], subfolder="unet")
    pipeline = StableVideoDiffusionPipeline.from_pretrained(args["pretrained_model_name_or_path"], controlnet=controlnet, unet=unet)
    pipeline.enable_model_cpu_offload()

    with Image.open(os.path.join(args["validation_image"], "demo.jpg")) as demo_image:
        validation_image = demo_image.convert('RGB')
    # Original size: trajectories in demo.json are expressed in these pixel units.
    width, height = validation_image.size
    validation_image = validation_image.resize((args["width"], args["height"]))

    validation_control_images, ids_embedding, vis_images = get_condition(
        target_size=(args["height"], args["width"]),
        original_size=(height, width),
        args=args, first_frame=validation_image,
        side=100, model_id=args["model_DIFT"])
    # (frames, H, W, C) -> (frames, C, H, W)
    ids_embedding = torch.stack(ids_embedding, dim=0).permute(0, 3, 1, 2)

    val_save_dir = os.path.join(args["output_dir"], "saved_video")
    os.makedirs(val_save_dir, exist_ok=True)

    video_frames = pipeline(
        validation_image, validation_control_images[:args["frame_number"]],
        decode_chunk_size=8, num_frames=args["frame_number"], motion_bucket_id=180,
        controlnet_cond_scale=1.0, height=args["height"], width=args["width"],
        ids_embedding=ids_embedding[:args["frame_number"]]).frames

    # Colour-map the grayscale trajectory visualisations for display.
    vis_images = [cv2.applyColorMap(np.array(img).astype(np.uint8), cv2.COLORMAP_JET) for img in vis_images]
    vis_images = [cv2.cvtColor(np.array(img).astype(np.uint8), cv2.COLOR_BGR2RGB) for img in vis_images]
    vis_images = [Image.fromarray(img) for img in vis_images]
    video_frames = [img for sublist in video_frames for img in sublist]
    save_gifs_side_by_side(video_frames, vis_images[:args["frame_number"]], val_save_dir, target_size=(width, height), duration=110)