import glob
import os

import numpy as np
import pandas as pd
import torch
import torch.utils.data as data
import torchvision.transforms as transforms
from numpy import random
from PIL import Image
from scipy.ndimage import gaussian_filter, map_coordinates

from src.utils import (
    CenterPad,
    CenterPadCrop_numpy,
    Distortion_with_flow_cpu,
    Distortion_with_flow_gpu,
    Normalize,
    RGB2Lab,
    SquaredPadding,
    ToTensor,
    read_flow,
)
def image_loader(path):
    """Load an image from disk and return it as an RGB PIL image."""
    with open(path, "rb") as f:
        with Image.open(f) as img:
            return img.convert("RGB")
class CenterCrop(object):
    """
    Center-crop a numpy array (HxWxC or HxW) to the given size.
    """

    def __init__(self, image_size):
        self.h0, self.w0 = image_size

    def __call__(self, input_numpy):
        if input_numpy.ndim == 3:
            h, w, _ = input_numpy.shape
            return input_numpy[
                (h - self.h0) // 2 : (h - self.h0) // 2 + self.h0,
                (w - self.w0) // 2 : (w - self.w0) // 2 + self.w0,
                :,
            ]
        h, w = input_numpy.shape
        return input_numpy[
            (h - self.h0) // 2 : (h - self.h0) // 2 + self.h0,
            (w - self.w0) // 2 : (w - self.w0) // 2 + self.w0,
        ]
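# Minimal usage sketch for CenterCrop (illustrative only; the frame shape and
# crop size below are made up for the example):
#
#   crop = CenterCrop((256, 256))
#   frame = np.zeros((480, 640, 3))   # hypothetical HxWxC frame
#   patch = crop(frame)               # -> shape (256, 256, 3)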
class VideosDataset(torch.utils.data.Dataset):
    def __init__(
        self,
        video_data_root,
        flow_data_root,
        mask_data_root,
        imagenet_folder,
        annotation_file_path,
        image_size,
        num_refs=5,  # max = 20
        image_transform=None,
        real_reference_probability=1,
        nonzero_placeholder_probability=0.5,
    ):
        self.video_data_root = video_data_root
        self.flow_data_root = flow_data_root
        self.mask_data_root = mask_data_root
        self.imagenet_folder = imagenet_folder
        self.image_transform = image_transform
        self.CenterPad = CenterPad(image_size)
        self.Resize = transforms.Resize(image_size)
        self.ToTensor = ToTensor()
        self.CenterCrop = transforms.CenterCrop(image_size)
        self.SquaredPadding = SquaredPadding(image_size[0])
        self.num_refs = num_refs

        assert os.path.exists(self.video_data_root), "video data root not found"
        assert os.path.exists(self.flow_data_root), "flow data root not found"
        assert os.path.exists(self.mask_data_root), "mask data root not found"
        assert os.path.exists(self.imagenet_folder), "imagenet folder not found"

        self.image_pairs = pd.read_csv(annotation_file_path, dtype=str)
        self.real_len = len(self.image_pairs)
        self.real_reference_probability = real_reference_probability
        self.nonzero_placeholder_probability = nonzero_placeholder_probability
        print("##### parsing image pairs in %s: %d pairs #####" % (video_data_root, self.__len__()))
    def __getitem__(self, index):
        # Each annotation row holds: video name, previous/current frame names,
        # forward-flow file, mask file, then `num_refs` reference image names.
        row = self.image_pairs.iloc[index, : 5 + self.num_refs].values.tolist()
        video_name, prev_frame, current_frame, flow_forward_name, mask_name = row[:5]
        reference_names = row[5:]

        video_path = os.path.join(self.video_data_root, video_name)
        flow_path = os.path.join(self.flow_data_root, video_name)
        mask_dir = os.path.join(self.mask_data_root, video_name)

        prev_frame_path = os.path.join(video_path, prev_frame)
        current_frame_path = os.path.join(video_path, current_frame)
        list_frame_path = sorted(glob.glob(os.path.join(video_path, "*")))

        reference_list = [os.path.join(self.imagenet_folder, name) for name in reference_names]
        flow_forward_path = os.path.join(flow_path, flow_forward_name)
        mask_path = os.path.join(mask_dir, mask_name)
        try:
            I1 = Image.open(prev_frame_path).convert("RGB")
            I2 = Image.open(current_frame_path).convert("RGB")
            try:
                I_reference_video = Image.open(list_frame_path[0]).convert("RGB")  # first frame of the clip
            except (IndexError, OSError):
                I_reference_video = Image.open(current_frame_path).convert("RGB")  # fall back to the current frame

            # Sample references until a colorized (non-grayscale) one is found.
            while reference_list:
                reference_path = random.choice(reference_list)
                I_reference_video_real = Image.open(reference_path)
                if I_reference_video_real.mode == "L":
                    reference_list.remove(reference_path)
                else:
                    break
            if not reference_list:  # every candidate was grayscale
                I_reference_video_real = I_reference_video
            else:
                I_reference_video_real = I_reference_video_real.convert("RGB")

            flow_forward = read_flow(flow_forward_path)  # numpy array
            mask = Image.open(mask_path)  # PIL image
            mask = np.array(self.Resize(mask))
            # binarize the mask
            mask[mask < 240] = 0
            mask[mask >= 240] = 1
            mask = self.ToTensor(mask)

            # transform
            I1 = self.image_transform(I1)
            I2 = self.image_transform(I2)
            I_reference_video = self.image_transform(I_reference_video)
            I_reference_video_real = self.image_transform(I_reference_video_real)
            flow_forward = self.Resize(self.ToTensor(flow_forward))

            if np.random.random() < self.real_reference_probability:
                I_reference_output = I_reference_video_real  # reference from imagenet
                placeholder = torch.zeros_like(I1)
                self_ref_flag = torch.zeros_like(I1)
            else:
                I_reference_output = I_reference_video  # reference from the ground-truth clip
                placeholder = I2 if np.random.random() < self.nonzero_placeholder_probability else torch.zeros_like(I1)
                self_ref_flag = torch.ones_like(I1)

            outputs = [
                I1,
                I2,
                I_reference_output,
                flow_forward,
                mask,
                placeholder,
                self_ref_flag,
                video_name + prev_frame,
                video_name + current_frame,
                reference_path,
            ]
        except Exception as e:
            print("error in reading image pair: %s" % str(self.image_pairs.iloc[index]))
            print(e)
            return self.__getitem__(np.random.randint(0, len(self.image_pairs)))
        return outputs
    def __len__(self):
        return len(self.image_pairs)
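# A minimal sketch of how VideosDataset might be wired up. The paths, crop
# size, transform list, and loader settings below are hypothetical, not taken
# from this repo:
#
#   dataset = VideosDataset(
#       video_data_root="data/videos",
#       flow_data_root="data/flows",
#       mask_data_root="data/masks",
#       imagenet_folder="data/imagenet",
#       annotation_file_path="data/pairs.csv",
#       image_size=(216, 384),
#       image_transform=transforms.Compose([CenterPad((216, 384)), RGB2Lab(), ToTensor(), Normalize()]),
#   )
#   loader = torch.utils.data.DataLoader(dataset, batch_size=2, shuffle=True, num_workers=4)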
def parse_imgnet_images(pairs_file):
    """Parse a '|'-separated pairs file into a list of (image_a, image_b) tuples."""
    pairs = []
    with open(pairs_file, "r") as f:
        for line in f:
            image_a, image_b = line.strip().split("|")[:2]
            pairs.append((image_a, image_b))
    return pairs
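# Expected pairs-file format, one '|'-separated pair per line (the file name
# and image names here are hypothetical):
#
#   n01440764/img_0001.JPEG|n01440764/img_0042.JPEG
#   n01443537/img_0010.JPEG|n01443537/img_0033.JPEG
#
#   pairs = parse_imgnet_images("imagenet_pairs.txt")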
class VideosDataset_ImageNet(data.Dataset):
    def __init__(
        self,
        imagenet_data_root,
        pairs_file,
        image_size,
        transforms_imagenet=None,
        distortion_level=3,
        brightnessjitter=0,
        nonzero_placeholder_probability=0.5,
        extra_reference_transform=None,
        real_reference_probability=1,
        distortion_device="cpu",
    ):
        self.imagenet_data_root = imagenet_data_root
        self.image_pairs = pd.read_csv(pairs_file, names=["i1", "i2"])
        self.transforms_imagenet_raw = transforms_imagenet or []
        self.extra_reference_transform = transforms.Compose(extra_reference_transform or [])
        self.real_reference_probability = real_reference_probability
        self.transforms_imagenet = transforms.Compose(transforms_imagenet or [])
        self.image_size = image_size
        self.real_len = len(self.image_pairs)
        self.distortion_level = distortion_level
        self.distortion_transform = (
            Distortion_with_flow_cpu() if distortion_device == "cpu" else Distortion_with_flow_gpu()
        )
        self.brightnessjitter = brightnessjitter
        self.flow_transform = transforms.Compose([CenterPadCrop_numpy(self.image_size), ToTensor()])
        self.nonzero_placeholder_probability = nonzero_placeholder_probability
        self.ToTensor = ToTensor()
        self.Normalize = Normalize()
        print("##### parsing imageNet pairs in %s: %d pairs #####" % (imagenet_data_root, self.__len__()))
    def __getitem__(self, index):
        pa, pb = self.image_pairs.iloc[index].values.tolist()
        if np.random.random() > 0.5:
            pa, pb = pb, pa
        image_a_path = os.path.join(self.imagenet_data_root, pa)
        image_b_path = os.path.join(self.imagenet_data_root, pb)

        I1 = image_loader(image_a_path)
        I2 = I1
        I_reference_video = I1
        I_reference_video_real = image_loader(image_b_path)

        # Generate a smooth random flow field: Gaussian-filtered noise scaled
        # by a random distortion strength.
        alpha = np.random.rand() * self.distortion_level
        distortion_range = 50
        random_state = np.random.RandomState(None)
        shape = self.image_size[0], self.image_size[1]
        # dx: flow in the vertical direction; dy: flow in the horizontal direction
        forward_dx = (
            gaussian_filter((random_state.rand(*shape) * 2 - 1), distortion_range, mode="constant", cval=0) * alpha * 1000
        )
        forward_dy = (
            gaussian_filter((random_state.rand(*shape) * 2 - 1), distortion_range, mode="constant", cval=0) * alpha * 1000
        )

        # Apply the raw transform chain to I1, keeping a copy of the image
        # just before the RGB->Lab conversion.
        for transform in self.transforms_imagenet_raw:
            if type(transform) is RGB2Lab:
                I1_raw = I1
            I1 = transform(I1)
        # Same for I2, but warp it with the synthetic flow right before the
        # RGB->Lab conversion.
        for transform in self.transforms_imagenet_raw:
            if type(transform) is RGB2Lab:
                I2 = self.distortion_transform(I2, forward_dx, forward_dy)
                I2_raw = I2
            I2 = transform(I2)

        # jitter the brightness (L channel) of the distorted frame
        I2[0:1, :, :] = I2[0:1, :, :] + torch.randn(1) * self.brightnessjitter

        I_reference_video = self.extra_reference_transform(I_reference_video)
        for transform in self.transforms_imagenet_raw:
            I_reference_video = transform(I_reference_video)
        I_reference_video_real = self.transforms_imagenet(I_reference_video_real)

        flow_forward_raw = np.stack((forward_dy, forward_dx), axis=-1)
        flow_forward = self.flow_transform(flow_forward_raw)
        # Update the mask: zero out pixels that warp outside the image border
        # and pixels whose warped color disagrees too strongly with I1.
        grid_x, grid_y = np.meshgrid(np.arange(self.image_size[0]), np.arange(self.image_size[1]), indexing="ij")
        grid = np.stack((grid_y, grid_x), axis=-1)
        grid_warp = grid + flow_forward_raw
        location_y = grid_warp[:, :, 0].flatten()
        location_x = grid_warp[:, :, 1].flatten()
        I2_raw = np.array(I2_raw).astype(float)
        I21_r = map_coordinates(I2_raw[:, :, 0], np.stack((location_x, location_y)), cval=-1).reshape(
            (self.image_size[0], self.image_size[1])
        )
        I21_g = map_coordinates(I2_raw[:, :, 1], np.stack((location_x, location_y)), cval=-1).reshape(
            (self.image_size[0], self.image_size[1])
        )
        I21_b = map_coordinates(I2_raw[:, :, 2], np.stack((location_x, location_y)), cval=-1).reshape(
            (self.image_size[0], self.image_size[1])
        )
        I21_raw = np.stack((I21_r, I21_g, I21_b), axis=2)
        mask = np.ones((self.image_size[0], self.image_size[1]))
        mask[(I21_raw[:, :, 0] == -1) & (I21_raw[:, :, 1] == -1) & (I21_raw[:, :, 2] == -1)] = 0
        mask[abs(I21_raw - I1_raw).sum(axis=-1) > 50] = 0
        mask = self.ToTensor(mask)

        if np.random.random() < self.real_reference_probability:
            I_reference_output = I_reference_video_real  # reference from the paired imagenet image
            placeholder = torch.zeros_like(I1)
            self_ref_flag = torch.zeros_like(I1)
        else:
            I_reference_output = I_reference_video  # reference derived from the same image
            placeholder = I2 if np.random.random() < self.nonzero_placeholder_probability else torch.zeros_like(I1)
            self_ref_flag = torch.ones_like(I1)
        return [I1, I2, I_reference_output, flow_forward, mask, placeholder, self_ref_flag, "holder", pb, pa]

    def __len__(self):
        return len(self.image_pairs)
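# A minimal sketch of how VideosDataset_ImageNet might be used. The paths,
# crop size, and transform lists below are hypothetical, not from this repo:
#
#   transforms_imagenet = [CenterPad((216, 384)), RGB2Lab(), ToTensor(), Normalize()]
#   dataset = VideosDataset_ImageNet(
#       imagenet_data_root="data/imagenet",
#       pairs_file="data/imagenet_pairs.csv",
#       image_size=(216, 384),
#       transforms_imagenet=transforms_imagenet,
#       extra_reference_transform=[transforms.RandomHorizontalFlip(0.5)],
#   )
#   I1, I2, ref, flow, mask, placeholder, flag, _, pb, pa = dataset[0]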