|
import os, io, csv, math, random |
|
import numpy as np |
|
from einops import rearrange |
|
|
|
import torch |
|
from decord import VideoReader |
|
import cv2 |
|
from scipy.ndimage import distance_transform_edt |
|
import torchvision.transforms as transforms |
|
from torch.utils.data.dataset import Dataset |
|
|
|
|
|
from PIL import Image |
|
|
|
def pil_image_to_numpy(image, is_maks=False, index=1, size=256):
    """Resize a PIL image to ``size x size`` and return it as a NumPy array.

    Args:
        image: PIL image to convert.
        is_maks: treat ``image`` as a label mask ("maks" is a historical typo,
            kept so existing keyword callers keep working). Masks keep their
            mode and are resized with nearest-neighbour sampling, so label
            values are never interpolated into new, meaningless labels.
        index: unused; accepted only for call-site compatibility.
        size: output side length in pixels.

    Returns:
        np.ndarray: for non-mask images, ``(size, size, 3)`` after conversion
        to RGB; for masks, the array shape follows the mask's mode (2-D for
        single-band modes such as 'P' or 'L').
    """
    if is_maks:
        # Pillow only forces NEAREST for "1"/"P" images; any other mask mode
        # (e.g. "L") would default to bicubic and blend neighbouring labels.
        return np.array(image.resize((size, size), resample=Image.NEAREST))

    if image.mode != 'RGB':
        image = image.convert('RGB')
    return np.array(image.resize((size, size)))
|
|
|
def numpy_to_pt(images: np.ndarray, is_mask=False) -> torch.FloatTensor:
    """Turn an (N, H, W[, C]) NumPy stack into an (N, C, H, W) float tensor.

    A 3-D input is treated as single-channel and gains a trailing channel
    axis. Non-mask inputs are scaled by 1/255; masks keep their raw values.
    """
    batch = images[..., np.newaxis] if images.ndim == 3 else images
    # NHWC -> NCHW, then promote to float32.
    tensor = torch.from_numpy(batch.transpose(0, 3, 1, 2)).float()
    return tensor if is_mask else tensor / 255
|
|
|
|
|
def find_largest_inner_rectangle_coordinates(mask_gray):
    """Locate the centre and radius of the largest circle inscribed in a mask.

    Despite the name, this does not fit a rectangle: it returns the pixel
    farthest (L2 distance) from any zero pixel, i.e. the peak of the
    distance transform.

    Args:
        mask_gray: 2-D array, non-zero inside the object.

    Returns:
        tuple: ``(maxLoc, radius)`` where ``maxLoc`` is the peak position as
        OpenCV's ``(x, y)`` = (column, row) and ``radius`` is the peak
        distance truncated to ``int``. An all-zero mask has an all-zero
        distance map, so ``radius`` is 0 in that case.
    """
    # Signature is distanceTransform(src, distanceType, maskSize[, dst[, dstType]]);
    # the 4th positional slot is the output array, so no extra flag goes there.
    refine_dist = cv2.distanceTransform(mask_gray.astype(np.uint8), cv2.DIST_L2, 5)
    _, maxVal, _, maxLoc = cv2.minMaxLoc(refine_dist)
    return maxLoc, int(maxVal)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class YoutubeVos(Dataset):
    """YouTube-VOS clips with instance masks, per-instance heatmaps and features.

    The sample list is the set of entries in ``feature_folder`` (shuffled once
    at construction). For a video id ``vid``:

    - frames are read from ``video_folder/vid``;
    - palette annotations from ``ann_folder/vid``;
    - ``torch.load``-able feature files from ``feature_folder/vid``.

    File names must have an integer stem (e.g. ``00005.jpg``). In each
    directory the first ``sample_n_frames`` files, in numeric order, are used.
    """

    def __init__(
        self, video_folder, ann_folder, feature_folder,
        sample_size=512, sample_stride=4, sample_n_frames=14,
    ):
        self.dataset = os.listdir(feature_folder)
        self.length = len(self.dataset)
        print(f"data scale: {self.length}")
        random.shuffle(self.dataset)

        self.video_folder = video_folder
        self.ann_folder = ann_folder
        self.feature_folder = feature_folder
        # NOTE: stored for API compatibility only; get_batch does not use it
        # and always takes the first sample_n_frames frames.
        self.sample_stride = sample_stride
        self.sample_n_frames = sample_n_frames
        # Kept as given: calculate_center_coordinates uses it as an int side
        # length for square heatmap canvases.
        self.sample_size = sample_size
        # Template blob that is resized and pasted at each instance centre.
        self.heatmap = self.gen_gaussian_heatmap()

        print("length", len(self.dataset))

        sample_size = tuple(sample_size) if not isinstance(sample_size, int) else (sample_size, sample_size)
        print("sample size", sample_size)

        # NOTE: not applied by the methods of this class.
        self.pixel_transforms = transforms.Compose([
            transforms.Resize(sample_size),
            transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
        ])

        # Resize + ImageNet normalisation applied to object crops in get_ID.
        self.idtransform = transforms.Compose([
            transforms.Resize((196, 196)),
            transforms.ToTensor(),
            transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
        ])

    def center_crop(self, img):
        """Crop the largest centred square from the last two axes of ``img``."""
        h, w = img.shape[-2:]
        min_dim = min(h, w)
        top = (h - min_dim) // 2
        left = (w - min_dim) // 2
        return img[..., top:top + min_dim, left:left + min_dim]

    def gen_gaussian_heatmap(self, imgSize=200):
        """Return an ``imgSize x imgSize`` uint8 isotropic Gaussian disc.

        The Gaussian has sigma = 40 px (independent of ``imgSize``) and is
        centred at ``(imgSize / 2, imgSize / 2)``. It is zeroed outside the
        inscribed circle and rescaled so its peak is 255.
        """
        sigma = 40
        circle_img = np.zeros((imgSize, imgSize), np.float32)
        circle_mask = cv2.circle(circle_img, (imgSize // 2, imgSize // 2), imgSize // 2, 1, -1)

        # Vectorised form of the per-pixel density; rows index i, columns j.
        offsets = np.arange(imgSize) - imgSize / 2
        sq_dist = offsets[:, None] ** 2 / sigma ** 2 + offsets[None, :] ** 2 / sigma ** 2
        gaussian = (1 / 2 / np.pi / sigma ** 2 * np.exp(-1 / 2 * sq_dist)).astype(np.float32)

        gaussian = gaussian * circle_mask
        return (gaussian / np.max(gaussian) * 255).astype(np.uint8)

    def calculate_center_coordinates(self, numpy_images, masks, ids, side=20):
        """Render one heatmap frame per mask, with a Gaussian blob per instance.

        For every label in ``ids[1:]`` that is present in a mask, the
        template ``self.heatmap`` is resized into a square of half-width
        ``side`` (clipped to the canvas). The square is centred on the
        instance's deepest interior point.

        Args:
            numpy_images: unused; kept for interface compatibility.
            masks: iterable of 2-D label maps, each (sample_size, sample_size).
            ids: labels to draw. ``ids[0]`` is skipped; get_batch passes the
                sorted output of ``np.unique``, so that is the smallest label
                (normally background 0).
            side: half-width of each blob in pixels.

        Returns:
            list of (sample_size, sample_size, 3) uint8 arrays, one per mask.
        """
        center_coordinates = []
        for mask in masks:
            canvas = np.zeros((self.sample_size, self.sample_size), np.float32)
            mask = np.asarray(mask)
            for instance_id in ids[1:]:
                instance_mask = (mask == instance_id).astype(np.uint8)
                # Instances can leave the frame. An empty mask has an all-zero
                # distance map whose peak is reported at (0, 0), which would
                # otherwise paint a spurious blob in the top-left corner.
                if not instance_mask.any():
                    continue
                try:
                    center, _ = find_largest_inner_rectangle_coordinates(instance_mask)
                except cv2.error:
                    print("find_largest_inner_rectangle_coordinates error")
                    continue

                # center is OpenCV's (x, y) = (column, row).
                x1 = max(center[0] - side, 0)
                x2 = min(center[0] + side, self.sample_size - 1)
                y1 = max(center[1] - side, 0)
                y2 = min(center[1] + side, self.sample_size - 1)
                if x2 <= x1 or y2 <= y1:
                    continue  # canvas too small to hold a non-empty blob

                # Later instances overwrite earlier ones where blobs overlap.
                canvas[y1:y2, x1:x2] = cv2.resize(self.heatmap, (x2 - x1, y2 - y1))

            center_coordinates.append(cv2.cvtColor(canvas.astype(np.uint8), cv2.COLOR_GRAY2RGB))
        return center_coordinates

    def get_ID(self, images_list, masks_list):
        """Crop the masked object out of the first frame and preprocess it.

        Unfinished: no identity encoder is wired into this class. After
        preprocessing the crop, this method therefore raises
        ``NotImplementedError``. It is not called elsewhere in this file.

        Args:
            images_list: frames; only ``images_list[0]`` is used (presumably
                an HxWx3 uint8 array -- TODO confirm).
            masks_list: a single 2-D binary mask for that frame.

        Raises:
            NotImplementedError: always.
        """
        image = images_list[0]
        mask = masks_list

        try:
            contours, _ = cv2.findContours(mask.astype(np.uint8), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
            x, y, w, h = cv2.boundingRect(contours[0])
            # Zero the background, then crop to the object's bounding box.
            mask = cv2.cvtColor(mask.astype(np.uint8), cv2.COLOR_GRAY2RGB)
            image = image * mask
            image = image[y:y + h, x:x + w]
        except (cv2.error, IndexError):
            # Best effort: keep whatever image we have (the full frame when no
            # contour was found).
            print("cv2.findContours error")

        image = Image.fromarray(image).convert('RGB')
        image = self.idtransform(image).unsqueeze(0).to(dtype=torch.float16)

        raise NotImplementedError(
            "get_ID: no identity encoder is configured to embed the preprocessed crop"
        )

    def get_batch(self, idx):
        """Load clip ``idx``, resampling a random clip until a usable one is found.

        A clip is skipped, and a random index drawn, when its annotation
        directory is missing or its first annotation frame has a single label.

        Returns:
            tuple: ``(pixel_values, mask_pixel_values, motion_values,
            heatmap_pixel_values, feature_images)``, with F frames of side S:

            - ``pixel_values``: (F, 3, S, S) float in [0, 1].
            - ``mask_pixel_values``: (F, 1, S, S) float label ids, for
              single-band masks.
            - ``motion_values``: the constant 180.
            - ``heatmap_pixel_values``: (F, 3, S, S) float in [0, 255].
            - ``feature_images``: each feature file's array stacked, with axes
              permuted by (0, 3, 1, 2). This assumes rank-3 per-frame features,
              presumably (H, W, C) -- TODO confirm.
        """
        def sort_frames(frame_name):
            return int(frame_name.split('.')[0])

        def load_frame(path, is_mask=False):
            # `with` guarantees the file handle is released after decoding.
            with Image.open(path) as img:
                if is_mask:
                    return pil_image_to_numpy(img.convert('P'), True, size=self.sample_size)
                return pil_image_to_numpy(img, size=self.sample_size)

        while True:
            videoid = self.dataset[idx]
            preprocessed_dir = os.path.join(self.video_folder, videoid)
            ann_folder = os.path.join(self.ann_folder, videoid)
            feature_folder_file = os.path.join(self.feature_folder, videoid)

            if not os.path.exists(ann_folder):
                idx = random.randint(0, len(self.dataset) - 1)
                print("os.path.exists(ann_folder), error")
                continue

            image_files = sorted(os.listdir(preprocessed_dir), key=sort_frames)[:self.sample_n_frames]
            depth_files = sorted(os.listdir(ann_folder), key=sort_frames)[:self.sample_n_frames]
            feature_file = sorted(os.listdir(feature_folder_file), key=sort_frames)[:self.sample_n_frames]

            # Validate the first annotation frame before decoding frames and
            # features, so a rejected clip costs only one small image read.
            with Image.open(os.path.join(ann_folder, depth_files[0])) as first_mask:
                ids = list(np.unique(first_mask.convert('P')))
            if len(ids) == 1:
                idx = random.randint(0, len(self.dataset) - 1)
                print("len(ids), error")
                continue

            numpy_images = np.array([load_frame(os.path.join(preprocessed_dir, f)) for f in image_files])
            pixel_values = numpy_to_pt(numpy_images)

            feature_images = np.array([
                np.array(torch.load(os.path.join(feature_folder_file, f))) for f in feature_file
            ])
            feature_images = torch.from_numpy(feature_images.transpose(0, 3, 1, 2))

            numpy_depth_images = np.array([
                load_frame(os.path.join(ann_folder, f), is_mask=True) for f in depth_files
            ])
            heatmap_pixel_values = np.array(
                self.calculate_center_coordinates(numpy_images, numpy_depth_images, ids)
            )

            # Masks and heatmaps keep their raw value ranges (no /255).
            mask_pixel_values = numpy_to_pt(numpy_depth_images, True)
            heatmap_pixel_values = numpy_to_pt(heatmap_pixel_values, True)

            # Constant placeholder; no motion statistic is computed here.
            motion_values = 180

            return pixel_values, mask_pixel_values, motion_values, heatmap_pixel_values, feature_images

    def __len__(self):
        return self.length

    def coordinates_normalize(self, center_coordinates):
        """Express every coordinate relative to the first one."""
        first_point = center_coordinates[0]
        return [point - first_point for point in center_coordinates]

    def normalize(self, images):
        """Map values from [0, 1] to [-1, 1]."""
        return 2.0 * images - 1.0

    def normalize_sam(self, images):
        """Standardise images channel-wise with the ImageNet mean and std.

        The statistics are broadcast as shape (1, 3, 1, 1), so the channel
        axis of ``images`` must be the third from last (e.g. (N, 3, H, W)).
        """
        mean = torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1)
        std = torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1)
        return (images - mean) / std

    def __getitem__(self, idx):
        """Return one sample dict; ``pixel_values`` are rescaled to [-1, 1]."""
        pixel_values, mask_pixel_values, motion_values, heatmap_pixel_values, feature_images = self.get_batch(idx)
        pixel_values = self.normalize(pixel_values)
        return dict(
            pixel_values=pixel_values,
            mask_pixel_values=mask_pixel_values,
            motion_values=motion_values,
            heatmap_pixel_values=heatmap_pixel_values,
            Id_Images=feature_images,
        )
|
|
|
|
|
|
|
def load_dinov2():
    """Fetch DINOv2 ViT-L/14 via torch.hub, move it to CUDA and set eval mode."""
    backbone = torch.hub.load('facebookresearch/dinov2', 'dinov2_vitl14')
    backbone = backbone.cuda()
    backbone.eval()
    return backbone
|
|
|
if __name__ == "__main__":
    # Visual smoke test: dump the frames, masks and heatmaps of one clip to ./vis.
    dataset = YoutubeVos(
        video_folder="/mmu-ocr/weijiawu/MovieDiffusion/ShowAnything/data/ref-youtube-vos/train/JPEGImages",
        ann_folder="/mmu-ocr/weijiawu/MovieDiffusion/ShowAnything/data/ref-youtube-vos/train/Annotations",
        feature_folder="/mmu-ocr/weijiawu/MovieDiffusion/ShowAnything/data/ref-youtube-vos/train/embedding",
        sample_size=256,
        sample_stride=1, sample_n_frames=16,
    )

    out_dir = "./vis"
    # cv2.imwrite returns False instead of raising when the directory is missing.
    os.makedirs(out_dir, exist_ok=True)

    dataloader = torch.utils.data.DataLoader(dataset, batch_size=1, num_workers=10)
    batch = next(iter(dataloader))

    images = ((batch["pixel_values"][0].permute(0, 2, 3, 1) + 1) / 2) * 255
    masks = batch["mask_pixel_values"][0].permute(0, 2, 3, 1) * 255
    heatmaps = batch["heatmap_pixel_values"][0].permute(0, 2, 3, 1)
    print(batch["pixel_values"].shape)

    for i in range(images.shape[0]):
        # Frames were decoded by PIL as RGB; OpenCV writes BGR.
        image = cv2.cvtColor(images[i].numpy().astype(np.uint8), cv2.COLOR_RGB2BGR)
        mask = masks[i].numpy()
        heatmap = heatmaps[i].numpy().astype(np.uint8)
        print(np.unique(mask))

        cv2.imwrite(os.path.join(out_dir, "image_{}.jpg".format(i)), image)
        # Labels > 1 exceed 255 after the *255 scaling; clip rather than let uint8 wrap.
        cv2.imwrite(os.path.join(out_dir, "mask_{}.jpg".format(i)), np.clip(mask, 0, 255).astype(np.uint8))
        cv2.imwrite(os.path.join(out_dir, "heatmap_{}.jpg".format(i)), heatmap)
        overlay = (heatmap * 0.5 + image * 0.5).astype(np.uint8)
        cv2.imwrite(os.path.join(out_dir, "{}.jpg".format(i)), overlay)