SwinTExCo / src /data /dataloader.py
duongttr's picture
Update new app
3d85088
raw
history blame
14 kB
import numpy as np
import pandas as pd
from src.utils import (
CenterPadCrop_numpy,
Distortion_with_flow_cpu,
Distortion_with_flow_gpu,
Normalize,
RGB2Lab,
ToTensor,
Normalize,
RGB2Lab,
ToTensor,
CenterPad,
read_flow,
SquaredPadding
)
import torch
import torch.utils.data as data
import torchvision.transforms as transforms
from numpy import random
import os
from PIL import Image
from scipy.ndimage.filters import gaussian_filter
from scipy.ndimage import map_coordinates
import glob
def image_loader(path):
    """Open the image file at *path* and return it converted to RGB."""
    # Open the raw file ourselves so the descriptor is closed deterministically.
    with open(path, "rb") as fp, Image.open(fp) as img:
        return img.convert("RGB")
class CenterCrop(object):
    """Center-crop a numpy array of shape (H, W) or (H, W, C) to a fixed size.

    The crop window is anchored at the center of the input; the input must be
    at least as large as the target size in both spatial dimensions.
    """

    def __init__(self, image_size):
        """
        Args:
            image_size: (height, width) of the output crop.
        """
        self.h0, self.w0 = image_size

    def __call__(self, input_numpy):
        """Return the centered (h0, w0) window of *input_numpy*.

        Any trailing axes (e.g. a channel axis) are kept untouched, so this
        works for both grayscale (H, W) and multi-channel (H, W, C) arrays.
        """
        # NOTE: the original allocated a throwaway np.zeros buffer that was
        # immediately overwritten; the crop itself is a plain slice (a view).
        h, w = input_numpy.shape[:2]
        top = (h - self.h0) // 2
        left = (w - self.w0) // 2
        return input_numpy[top:top + self.h0, left:left + self.w0]
class VideosDataset(torch.utils.data.Dataset):
    """Consecutive video-frame pairs for exemplar-based video colorization.

    Each item bundles two consecutive frames, the forward optical flow and
    occlusion mask between them, and a reference image that is either a real
    ImageNet exemplar or the video's own first frame (ground-truth reference).
    """

    def __init__(
        self,
        video_data_root,
        flow_data_root,
        mask_data_root,
        imagenet_folder,
        annotation_file_path,
        image_size,
        num_refs=5,  # max = 20
        image_transform=None,
        real_reference_probability=1,
        nonzero_placeholder_probability=0.5,
    ):
        """
        Args:
            video_data_root: root folder holding one sub-folder of frames per video.
            flow_data_root: root folder holding pre-computed forward-flow files.
            mask_data_root: root folder holding occlusion-mask images.
            imagenet_folder: folder containing candidate reference images.
            annotation_file_path: CSV whose rows are
                (video, prev_frame, cur_frame, flow, mask, ref_1 .. ref_num_refs).
            image_size: (height, width) target size.
            num_refs: number of reference columns to read from each CSV row.
            image_transform: transform applied to every RGB image.
            real_reference_probability: chance of using the ImageNet reference
                instead of the video's own first frame.
            nonzero_placeholder_probability: chance of feeding I2 (instead of
                zeros) as the placeholder when the ground-truth reference is used.
        """
        self.video_data_root = video_data_root
        self.flow_data_root = flow_data_root
        self.mask_data_root = mask_data_root
        self.imagenet_folder = imagenet_folder
        self.image_transform = image_transform
        self.CenterPad = CenterPad(image_size)
        self.Resize = transforms.Resize(image_size)
        self.ToTensor = ToTensor()
        self.CenterCrop = transforms.CenterCrop(image_size)
        self.SquaredPadding = SquaredPadding(image_size[0])
        self.num_refs = num_refs

        assert os.path.exists(self.video_data_root), "find no video dataroot"
        assert os.path.exists(self.flow_data_root), "find no flow dataroot"
        assert os.path.exists(self.imagenet_folder), "find no imagenet folder"
        # NOTE(review): mask_data_root is not existence-checked here — the
        # original behaved the same; missing masks surface in __getitem__.
        self.image_pairs = pd.read_csv(annotation_file_path, dtype=str)
        self.real_len = len(self.image_pairs)
        self.real_reference_probability = real_reference_probability
        self.nonzero_placeholder_probability = nonzero_placeholder_probability
        print("##### parsing image pairs in %s: %d pairs #####" % (video_data_root, self.__len__()))

    def __getitem__(self, index):
        """Return [I1, I2, reference, flow_forward, mask, placeholder,
        self_ref_flag, prev_frame_id, cur_frame_id, reference_path].

        On any read/transform failure, logs the offending row and retries
        with a uniformly random other index.
        """
        # First 5 columns are fixed metadata; the rest are candidate reference
        # image names. (Generalized: the original hard-coded exactly 5 refs.)
        row = self.image_pairs.iloc[index, : 5 + self.num_refs].values.tolist()
        video_name, prev_frame, current_frame, flow_forward_name, mask_name = row[:5]
        reference_names = row[5:]

        video_path = os.path.join(self.video_data_root, video_name)
        flow_path = os.path.join(self.flow_data_root, video_name)
        prev_frame_path = os.path.join(video_path, prev_frame)
        current_frame_path = os.path.join(video_path, current_frame)
        list_frame_path = sorted(glob.glob(os.path.join(video_path, "*")))

        reference_list = [os.path.join(self.imagenet_folder, name) for name in reference_names]
        flow_forward_path = os.path.join(flow_path, flow_forward_name)
        mask_path = os.path.join(self.mask_data_root, video_name, mask_name)

        try:
            I1 = Image.open(prev_frame_path).convert("RGB")
            I2 = Image.open(current_frame_path).convert("RGB")
            try:
                # Ground-truth reference: the first frame of the video.
                I_reference_video = Image.open(list_frame_path[0]).convert("RGB")
            except Exception:
                # Best-effort fallback if the first frame is unreadable.
                I_reference_video = Image.open(current_frame_path).convert("RGB")

            # Pick a random colorized reference; drop grayscale ('L') candidates.
            reference_path = None
            while reference_list:
                reference_path = random.choice(reference_list)
                I_reference_video_real = Image.open(reference_path)
                if I_reference_video_real.mode == 'L':
                    reference_list.remove(reference_path)
                else:
                    break
            if not reference_list:
                # Every candidate was grayscale: reuse the ground-truth reference.
                I_reference_video_real = I_reference_video

            flow_forward = read_flow(flow_forward_path)  # numpy
            mask = Image.open(mask_path)  # PIL
            mask = self.Resize(mask)
            mask = np.array(mask)
            # Binarize the mask: >= 240 -> 1 (valid), everything else -> 0.
            mask[mask < 240] = 0
            mask[mask >= 240] = 1
            mask = self.ToTensor(mask)

            # Transform images and flow to tensors at the target size.
            I1 = self.image_transform(I1)
            I2 = self.image_transform(I2)
            I_reference_video = self.image_transform(I_reference_video)
            I_reference_video_real = self.image_transform(I_reference_video_real)
            flow_forward = self.ToTensor(flow_forward)
            flow_forward = self.Resize(flow_forward)

            if np.random.random() < self.real_reference_probability:
                I_reference_output = I_reference_video_real  # reference from ImageNet
                placeholder = torch.zeros_like(I1)
                self_ref_flag = torch.zeros_like(I1)
            else:
                I_reference_output = I_reference_video  # reference from ground truth
                placeholder = I2 if np.random.random() < self.nonzero_placeholder_probability else torch.zeros_like(I1)
                self_ref_flag = torch.ones_like(I1)

            outputs = [
                I1,
                I2,
                I_reference_output,
                flow_forward,
                mask,
                placeholder,
                self_ref_flag,
                video_name + prev_frame,
                video_name + current_frame,
                reference_path,
            ]
        except Exception as e:
            # BUG FIX: was `self.image_pairs[index]`, which looks up a *column*
            # by label and raises KeyError, masking the real error being logged.
            print("error in reading image pair: %s" % str(self.image_pairs.iloc[index]))
            print(e)
            return self.__getitem__(np.random.randint(0, len(self.image_pairs)))
        return outputs

    def __len__(self):
        """Number of annotated frame pairs."""
        return len(self.image_pairs)
def parse_imgnet_images(pairs_file):
    """Parse a '|'-separated pairs file into a list of (image_a, image_b) tuples."""
    pairs = []
    # Iterate the file lazily; each line is "image_a|image_b".
    with open(pairs_file, "r") as fh:
        for raw_line in fh:
            fields = raw_line.strip().split("|")
            pairs.append((fields[0], fields[1]))
    return pairs
class VideosDataset_ImageNet(data.Dataset):
    """Synthetic "video" pairs built from single ImageNet images.

    The second frame I2 is generated by warping I1 with a random smooth flow
    field, so (I1, I2, flow, mask) mimic a consecutive-frame training sample
    without requiring real video data.
    """

    def __init__(
        self,
        imagenet_data_root,
        pairs_file,
        image_size,
        transforms_imagenet=None,
        distortion_level=3,
        brightnessjitter=0,
        nonzero_placeholder_probability=0.5,
        extra_reference_transform=None,
        real_reference_probability=1,
        distortion_device='cpu'
    ):
        # Root folder containing the images referenced by pairs_file.
        self.imagenet_data_root = imagenet_data_root
        # Two-column CSV: paths of the paired images (i1, i2).
        self.image_pairs = pd.read_csv(pairs_file, names=['i1', 'i2'])
        # Kept as a raw list so __getitem__ can intercept the RGB2Lab step.
        self.transforms_imagenet_raw = transforms_imagenet
        self.extra_reference_transform = transforms.Compose(extra_reference_transform)
        self.real_reference_probability = real_reference_probability
        self.transforms_imagenet = transforms.Compose(transforms_imagenet)
        self.image_size = image_size
        self.real_len = len(self.image_pairs)
        # Scales the magnitude of the random warp applied to I2.
        self.distortion_level = distortion_level
        self.distortion_transform = Distortion_with_flow_cpu() if distortion_device == 'cpu' else Distortion_with_flow_gpu()
        # Magnitude of the random brightness offset added to I2's first channel.
        self.brightnessjitter = brightnessjitter
        self.flow_transform = transforms.Compose([CenterPadCrop_numpy(self.image_size), ToTensor()])
        self.nonzero_placeholder_probability = nonzero_placeholder_probability
        self.ToTensor = ToTensor()
        self.Normalize = Normalize()
        print("##### parsing imageNet pairs in %s: %d pairs #####" % (imagenet_data_root, self.__len__()))

    def __getitem__(self, index):
        """Return [I1, I2, reference, flow_forward, mask, placeholder,
        self_ref_flag, "holder", pb, pa] for one synthetic pair.
        """
        pa, pb = self.image_pairs.iloc[index].values.tolist()
        # Randomly swap which image of the pair is the content vs. reference.
        if np.random.random() > 0.5:
            pa, pb = pb, pa

        image_a_path = os.path.join(self.imagenet_data_root, pa)
        image_b_path = os.path.join(self.imagenet_data_root, pb)

        I1 = image_loader(image_a_path)
        I2 = I1
        I_reference_video = I1
        I_reference_video_real = image_loader(image_b_path)
        # print("i'm here get image 2")

        # Generate a random smooth flow field for the synthetic motion.
        alpha = np.random.rand() * self.distortion_level
        distortion_range = 50
        random_state = np.random.RandomState(None)
        shape = self.image_size[0], self.image_size[1]
        # dx: flow on the vertical direction; dy: flow on the horizontal direction
        forward_dx = (
            gaussian_filter((random_state.rand(*shape) * 2 - 1), distortion_range, mode="constant", cval=0) * alpha * 1000
        )
        forward_dy = (
            gaussian_filter((random_state.rand(*shape) * 2 - 1), distortion_range, mode="constant", cval=0) * alpha * 1000
        )
        # print("i'm here get image 3")

        # Apply the raw transform list to I1, capturing the RGB image just
        # before RGB2Lab converts the color space.
        for transform in self.transforms_imagenet_raw:
            if type(transform) is RGB2Lab:
                I1_raw = I1
            I1 = transform(I1)
        # Same for I2, but warp it with the random flow right before RGB2Lab
        # (i.e. while it is still an RGB image).
        for transform in self.transforms_imagenet_raw:
            if type(transform) is RGB2Lab:
                I2 = self.distortion_transform(I2, forward_dx, forward_dy)
                I2_raw = I2
            I2 = transform(I2)
        # print("i'm here get image 4")

        # Jitter the first channel of I2 (presumably L in Lab — TODO confirm).
        I2[0:1, :, :] = I2[0:1, :, :] + torch.randn(1) * self.brightnessjitter

        I_reference_video = self.extra_reference_transform(I_reference_video)
        for transform in self.transforms_imagenet_raw:
            I_reference_video = transform(I_reference_video)

        I_reference_video_real = self.transforms_imagenet(I_reference_video_real)
        # print("i'm here get image 5")

        flow_forward_raw = np.stack((forward_dy, forward_dx), axis=-1)
        flow_forward = self.flow_transform(flow_forward_raw)

        # Update the mask for pixels warped from outside the image border:
        # sample I2 back along the flow; cval=-1 marks out-of-bounds samples.
        grid_x, grid_y = np.meshgrid(np.arange(self.image_size[0]), np.arange(self.image_size[1]), indexing="ij")
        grid = np.stack((grid_y, grid_x), axis=-1)
        grid_warp = grid + flow_forward_raw
        location_y = grid_warp[:, :, 0].flatten()
        location_x = grid_warp[:, :, 1].flatten()
        # NOTE(review): I2_raw is the warped PIL image; np.array coerces it here.
        I2_raw = np.array(I2_raw).astype(float)
        I21_r = map_coordinates(I2_raw[:, :, 0], np.stack((location_x, location_y)), cval=-1).reshape(
            (self.image_size[0], self.image_size[1])
        )
        I21_g = map_coordinates(I2_raw[:, :, 1], np.stack((location_x, location_y)), cval=-1).reshape(
            (self.image_size[0], self.image_size[1])
        )
        I21_b = map_coordinates(I2_raw[:, :, 2], np.stack((location_x, location_y)), cval=-1).reshape(
            (self.image_size[0], self.image_size[1])
        )
        I21_raw = np.stack((I21_r, I21_g, I21_b), axis=2)
        mask = np.ones((self.image_size[0], self.image_size[1]))
        # Zero the mask where all three channels sampled out of bounds...
        mask[(I21_raw[:, :, 0] == -1) & (I21_raw[:, :, 1] == -1) & (I21_raw[:, :, 2] == -1)] = 0
        # ...and where the round-trip warp disagrees with I1 by more than 50
        # (summed over channels). I1_raw is a PIL image here; numpy coerces it
        # during the subtraction — assumed same size as I21_raw; TODO confirm.
        mask[abs(I21_raw - I1_raw).sum(axis=-1) > 50] = 0
        mask = self.ToTensor(mask)
        # print("i'm here get image 6")

        # Choose between the real (ImageNet exemplar) and self reference.
        if np.random.random() < self.real_reference_probability:
            I_reference_output = I_reference_video_real
            placeholder = torch.zeros_like(I1)
            self_ref_flag = torch.zeros_like(I1)
        else:
            I_reference_output = I_reference_video
            placeholder = I2 if np.random.random() < self.nonzero_placeholder_probability else torch.zeros_like(I1)
            self_ref_flag = torch.ones_like(I1)

        # except Exception as e:
        #     if combo_path is not None:
        #         print("problem in ", combo_path)
        #     print("problem in, ", image_a_path)
        #     print(e)
        #     return self.__getitem__(np.random.randint(0, len(self.image_pairs)))
        # print("i'm here get image 7")
        return [I1, I2, I_reference_output, flow_forward, mask, placeholder, self_ref_flag, "holder", pb, pa]

    def __len__(self):
        """Number of image pairs in the CSV."""
        return len(self.image_pairs)