MaskGAN / utils /dataloader.py
白鹭先生
init
73ca179
raw
history blame
No virus
3.7 kB
from random import randint
import cv2
import numpy as np
from PIL import Image
from torch.utils.data.dataset import Dataset
from .utils import cvtColor, preprocess_input
def look_image(image_name, image):
    """Display ``image`` in an OpenCV window and block until a key is pressed.

    Debug helper only.

    Args:
        image_name: Title of the window passed to ``cv2.imshow``.
        image: Anything ``np.array`` can convert to a 3-channel array
            (presumably an RGB PIL image -- confirm against callers).
    """
    pixels = np.array(image)
    # Reverse the channel order before display; assumes the input is RGB and
    # that cv2.imshow wants BGR -- TODO confirm.
    swapped = cv2.cvtColor(pixels, cv2.COLOR_BGR2RGB)
    cv2.imshow(image_name, swapped)
    # waitKey(0) waits indefinitely for a key press.
    cv2.waitKey(0)
def get_new_img_size(width, height, img_min_side=600):
    """Scale ``(width, height)`` so that the shorter side equals ``img_min_side``.

    The aspect ratio is kept; the longer side is truncated to an integer.

    Args:
        width: Original width.
        height: Original height.
        img_min_side: Desired length of the shorter side.

    Returns:
        tuple: ``(resized_width, resized_height)`` as integers.
    """
    if width <= height:
        # Width is the shorter (or equal) side: pin it, scale the height.
        scale = float(img_min_side) / width
        return int(img_min_side), int(scale * height)

    # Height is the shorter side: pin it, scale the width.
    scale = float(img_min_side) / height
    return int(scale * width), int(img_min_side)
class MASKGANDataset(Dataset):
    """Paired dataset yielding ``(masked, original)`` image arrays.

    Each entry of ``train_lines`` is a text line with two image paths
    separated by a single space: the original image first, the masked image
    second. Trailing whitespace (e.g. a newline) after the second path is
    tolerated.

    Args:
        train_lines: Sequence of annotation lines as described above.
        lr_shape: Target size of the masked image; ``__getitem__`` reads it
            as ``(height, width)``.
        hr_shape: Target size of the original image, same indexing.
    """

    def __init__(self, train_lines, lr_shape, hr_shape):
        super().__init__()

        self.train_lines = train_lines
        self.train_batches = len(train_lines)

        self.lr_shape = lr_shape
        self.hr_shape = hr_shape

    def __len__(self):
        """Return the number of annotation lines."""
        return self.train_batches

    def __getitem__(self, index):
        """Load, augment, resize and normalize the image pair at ``index``.

        Returns:
            tuple: ``(image_masked, image_origin)`` arrays whose axes were
            permuted with ``[2, 0, 1]`` (channel-first, assuming
            ``preprocess_input`` keeps the H x W x C layout -- TODO confirm).
        """
        # Wrap around so indices past the end still map to a sample.
        index = index % self.train_batches

        image_list = self.train_lines[index].split(' ')
        image_origin = Image.open(image_list[0])
        # The last field may carry a trailing newline; split() strips it.
        image_masked = Image.open(image_list[1].split()[0])

        image_origin, image_masked = self.get_random_data(
            image_origin, image_masked, self.hr_shape
        )

        # PIL's resize takes (width, height), hence the swapped indices.
        image_origin = image_origin.resize(
            (self.hr_shape[1], self.hr_shape[0]), Image.BICUBIC
        )
        image_masked = image_masked.resize(
            (self.lr_shape[1], self.lr_shape[0]), Image.BICUBIC
        )

        # look_image('origin', image_origin)
        # look_image('masked', image_masked)

        # preprocess_input is a project helper; presumably it normalizes with
        # the given per-channel mean/std of 0.5 -- confirm in utils.
        image_origin = np.transpose(
            preprocess_input(
                np.array(image_origin, dtype=np.float32),
                [0.5, 0.5, 0.5],
                [0.5, 0.5, 0.5],
            ),
            [2, 0, 1],
        )
        image_masked = np.transpose(
            preprocess_input(
                np.array(image_masked, dtype=np.float32),
                [0.5, 0.5, 0.5],
                [0.5, 0.5, 0.5],
            ),
            [2, 0, 1],
        )

        return np.array(image_masked), np.array(image_origin)

    def rand(self, a=0, b=1):
        """Return a uniform random float in ``[a, b)``."""
        return np.random.rand() * (b - a) + a

    @staticmethod
    def _scale_sat_val(image, sat, val):
        """Scale saturation and value of ``image`` in HSV space.

        Returns a new uint8 PIL image built from the jittered pixels.
        """
        x = cv2.cvtColor(np.array(image, np.float32) / 255, cv2.COLOR_RGB2HSV)
        x[..., 1] *= sat
        x[..., 2] *= val
        # Clamp back into range: hue capped at 360, saturation/value capped
        # at 1, and nothing below 0.
        x[x[:, :, 0] > 360, 0] = 360
        x[:, :, 1:][x[:, :, 1:] > 1] = 1
        x[x < 0] = 0
        image_data = cv2.cvtColor(x, cv2.COLOR_HSV2RGB) * 255
        return Image.fromarray(np.uint8(image_data))

    def get_random_data(self, image_origin, image_masked, input_shape, jitter=.3, hue=.1, sat=1.5, val=1.5, random=True):
        """Apply the same random saturation/value jitter to both images.

        Args:
            image_origin: Original image (PIL image).
            image_masked: Masked counterpart (PIL image).
            input_shape: Unused here; resizing is done by the caller.
            jitter: Unused here.
            hue: Bound for the hue draw. The drawn value is not applied to
                the images.
            sat: Upper bound of the saturation scale factor; the factor is
                drawn from ``[1, sat)`` or inverted with probability 0.5.
            val: Same as ``sat`` but for the value (brightness) channel.
            random: When False, skip the augmentation and return the
                RGB-converted images unchanged.

        Returns:
            tuple: ``(image_origin, image_masked)`` as PIL images.
        """
        # ------------------------------#
        #   Convert both images to RGB
        #   (cvtColor is a project helper; presumably returns an RGB PIL
        #   image -- confirm in utils).
        # ------------------------------#
        image_origin = cvtColor(image_origin)
        image_masked = cvtColor(image_masked)

        # Honor the `random` flag: previously it was accepted but ignored,
        # so the augmentation could not be disabled.
        if not random:
            return image_origin, image_masked

        # ------------------------------------------#
        #   Color jitter
        # ------------------------------------------#
        # The hue offset is drawn but never applied; the draw is kept so that
        # seeded runs produce the same random sequence as before.
        hue = self.rand(-hue, hue)
        sat = self.rand(1, sat) if self.rand() < .5 else 1 / self.rand(1, sat)
        val = self.rand(1, val) if self.rand() < .5 else 1 / self.rand(1, val)

        # Both images get identical factors so the pair stays consistent.
        return (
            self._scale_sat_val(image_origin, sat, val),
            self._scale_sat_val(image_masked, sat, val),
        )
def MASKGAN_dataset_collate(batch):
    """Stack a batch of ``(low, high)`` sample pairs into two arrays.

    Args:
        batch: Iterable of 2-item samples, first the low-resolution (masked)
            item and then the high-resolution (original) item.

    Returns:
        tuple: ``(images_l, images_h)`` numpy arrays, each with the batch
        dimension first.
    """
    # Materialize once so single-pass iterables work and every sample is
    # checked to unpack into exactly two items.
    pairs = [(low, high) for low, high in batch]
    lows = [low for low, _ in pairs]
    highs = [high for _, high in pairs]
    return np.array(lows), np.array(highs)