# File size: 3,697 Bytes
# 73ca179
from random import randint
import cv2
import numpy as np
from PIL import Image
from torch.utils.data.dataset import Dataset
from .utils import cvtColor, preprocess_input
def look_image(image_name, image):
    """Show ``image`` in an OpenCV window and wait for a key press.

    Debugging helper: the image is converted to a NumPy array, its first and
    third channels are swapped, and the result is displayed in a window
    titled ``image_name``.

    Args:
        image_name: Title of the window passed to ``cv2.imshow``.
        image: Any object ``np.array`` can turn into an image array
            (the commented-out call sites in this file pass PIL images).
    """
    pixels = np.array(image)
    # The conversion only swaps channel order, so it maps an RGB array to
    # the BGR layout that imshow displays.
    swapped = cv2.cvtColor(pixels, cv2.COLOR_BGR2RGB)
    cv2.imshow(image_name, swapped)
    # A delay of 0 blocks until a key is pressed in the window.
    cv2.waitKey(0)
def get_new_img_size(width, height, img_min_side=600):
    """Scale a size so that its shorter side equals ``img_min_side``.

    The longer side is multiplied by the same factor and truncated to an
    integer, so the aspect ratio is kept up to rounding.

    Args:
        width: Original width.
        height: Original height.
        img_min_side: Target length of the shorter side.

    Returns:
        A ``(resized_width, resized_height)`` tuple of ints.
    """
    if width <= height:
        # Width is the shorter (or equal) side: pin it to the target.
        scale = float(img_min_side) / width
        new_height = int(scale * height)
        return int(img_min_side), new_height

    # Height is the shorter side: pin it to the target.
    scale = float(img_min_side) / height
    return int(scale * width), int(img_min_side)
class MASKGANDataset(Dataset):
    """Dataset of (masked, original) image pairs read from annotation lines.

    Each entry of ``train_lines`` is a text line with two image paths
    separated by a single space: the original image first, the masked image
    second. Paths that themselves contain spaces are not supported.

    Args:
        train_lines: Sequence of annotation lines as described above.
        lr_shape: Target size of the masked image; index 0 is used as the
            height and index 1 as the width when resizing.
        hr_shape: Target size of the original image, indexed the same way.
    """

    def __init__(self, train_lines, lr_shape, hr_shape):
        super().__init__()

        self.train_lines = train_lines
        self.train_batches = len(train_lines)

        self.lr_shape = lr_shape
        self.hr_shape = hr_shape

    def __len__(self):
        """Return the number of annotation lines."""
        return self.train_batches

    def __getitem__(self, index):
        """Load, augment, resize and preprocess one image pair.

        Args:
            index: Sample index; it is taken modulo the dataset length.

        Returns:
            A ``(image_masked, image_origin)`` tuple of NumPy arrays, each
            the output of ``preprocess_input`` with its axes reordered by
            ``[2, 0, 1]`` (channels-first, assuming ``preprocess_input``
            keeps the height/width/channel layout - TODO confirm).
        """
        # Wrap around so indices past the end map back onto existing lines.
        index = index % self.train_batches

        fields = self.train_lines[index].split(' ')
        origin_path = fields[0]
        # split() with no argument drops a trailing newline on the last field.
        masked_path = fields[1].split()[0]

        # Image.open is lazy and keeps the file open; the augmentation step
        # reads the pixels and returns new images, so both files can be
        # closed as soon as it finishes.
        with Image.open(origin_path) as origin_file, Image.open(masked_path) as masked_file:
            image_origin, image_masked = self.get_random_data(origin_file, masked_file, self.hr_shape)

        # PIL's resize takes (width, height), hence the swapped indices.
        image_origin = image_origin.resize((self.hr_shape[1], self.hr_shape[0]), Image.BICUBIC)
        image_masked = image_masked.resize((self.lr_shape[1], self.lr_shape[0]), Image.BICUBIC)

        image_origin = self._to_model_array(image_origin)
        image_masked = self._to_model_array(image_masked)

        return image_masked, image_origin

    def rand(self, a=0, b=1):
        """Return a random float drawn uniformly from ``[a, b)``."""
        return np.random.rand()*(b-a) + a

    def get_random_data(self, image_origin, image_masked, input_shape, jitter=.3, hue=.1, sat=1.5, val=1.5, random=True):
        """Apply the same random saturation/value distortion to both images.

        Args:
            image_origin: Original PIL image.
            image_masked: Masked PIL image.
            input_shape: Currently unused.
            jitter: Currently unused.
            hue: Upper bound of the hue factor that is drawn (see note below).
            sat: Upper bound of the saturation scale; the factor is drawn
                from ``[1, sat)`` or, with equal probability, its reciprocal.
            val: Upper bound of the value (brightness) scale, drawn like
                ``sat``.
            random: Currently unused; the distortion is always applied.

        Returns:
            A ``(image_origin, image_masked)`` tuple of new 8-bit PIL images.
        """
        # Convert both inputs with the project's cvtColor helper
        # (presumably forces RGB mode - verify against utils).
        image_origin = cvtColor(image_origin)
        image_masked = cvtColor(image_masked)

        # Color distortion factors, shared by both images so the pair
        # stays consistent.
        # NOTE(review): the hue factor is drawn but never applied to the
        # images; the draw is kept so the random number sequence is
        # unchanged. Confirm whether a hue shift was intended.
        hue = self.rand(-hue, hue)
        sat = self.rand(1, sat) if self.rand()<.5 else 1/self.rand(1, sat)
        val = self.rand(1, val) if self.rand()<.5 else 1/self.rand(1, val)

        distorted_origin = self._scale_saturation_value(image_origin, sat, val)
        distorted_masked = self._scale_saturation_value(image_masked, sat, val)
        return distorted_origin, distorted_masked

    @staticmethod
    def _scale_saturation_value(image, sat, val):
        """Scale the S and V channels of ``image`` in HSV space and return a new PIL image."""
        hsv = cv2.cvtColor(np.array(image, np.float32)/255, cv2.COLOR_RGB2HSV)
        hsv[..., 1] *= sat
        hsv[..., 2] *= val
        # Clamp to the ranges OpenCV uses for float32 HSV:
        # H in [0, 360], S and V in [0, 1].
        hsv[hsv[:, :, 0]>360, 0] = 360
        hsv[:, :, 1:][hsv[:, :, 1:]>1] = 1
        hsv[hsv<0] = 0
        rgb = cv2.cvtColor(hsv, cv2.COLOR_HSV2RGB)*255
        return Image.fromarray(np.uint8(rgb))

    @staticmethod
    def _to_model_array(image):
        """Run ``preprocess_input`` on a float32 copy of ``image`` and reorder its axes by ``[2, 0, 1]``."""
        pixels = np.array(image, dtype=np.float32)
        # The two lists are forwarded to preprocess_input unchanged
        # (presumably per-channel mean and std - verify against utils).
        processed = preprocess_input(pixels, [0.5,0.5,0.5], [0.5,0.5,0.5])
        return np.array(np.transpose(processed, [2,0,1]))
def MASKGAN_dataset_collate(batch):
    """Stack a batch of sample pairs into two NumPy arrays.

    Args:
        batch: Iterable of two-element samples; the first element of each
            sample goes to the first output, the second to the second.

    Returns:
        A ``(images_l, images_h)`` tuple of arrays, each with the samples
        stacked along a new leading axis.
    """
    images_l = []
    images_h = []
    for img_l, img_h in batch:
        images_l.append(img_l)
        images_h.append(img_h)
    return np.array(images_l), np.array(images_h)