|
from random import randint |
|
|
|
import cv2 |
|
import numpy as np |
|
from PIL import Image |
|
from torch.utils.data.dataset import Dataset |
|
|
|
from .utils import cvtColor, preprocess_input |
|
|
|
def look_image(image_name, image):
    """Display ``image`` in an OpenCV window titled ``image_name`` and wait for a key press.

    The first and third channels are swapped before display. Presumably this is
    because ``image`` is RGB (PIL order) while ``cv2.imshow`` expects BGR;
    confirm at call sites.
    """
    swapped = cv2.cvtColor(np.array(image), cv2.COLOR_BGR2RGB)
    cv2.imshow(image_name, swapped)
    # A delay of 0 makes waitKey block until any key is pressed.
    cv2.waitKey(0)
|
|
|
|
|
def get_new_img_size(width, height, img_min_side=600):
    """Rescale ``(width, height)`` so that the shorter side equals ``img_min_side``.

    The aspect ratio is preserved, and the longer side is converted with ``int``
    (truncation). When the sides are equal, width is treated as the shorter side.

    Returns:
        ``(resized_width, resized_height)`` as ints.
    """
    width_is_short = width <= height
    short_side = width if width_is_short else height
    scale = float(img_min_side) / short_side
    if width_is_short:
        return int(img_min_side), int(scale * height)
    return int(scale * width), int(img_min_side)
|
|
|
class MASKGANDataset(Dataset):
    """Dataset of (masked, original) image pairs for MASKGAN training.

    Each entry of ``train_lines`` is a text line whose first whitespace-separated
    field is the path of the original image and whose second field is the path of
    the masked image. ``__getitem__`` returns ``(masked, original)`` as
    channel-first arrays:

    - the masked image is resized to ``lr_shape``;
    - the original image is resized to ``hr_shape``.
    """

    def __init__(self, train_lines, lr_shape, hr_shape):
        """Store the annotation lines and the target low/high-resolution shapes.

        Args:
            train_lines: sequence of annotation lines (see class docstring).
            lr_shape: target shape for the masked image. Indexed as
                ``[0]`` -> height and ``[1]`` -> width when resizing.
            hr_shape: target shape for the original image, indexed the same way.
        """
        super(MASKGANDataset, self).__init__()

        self.train_lines = train_lines
        self.train_batches = len(train_lines)

        self.lr_shape = lr_shape
        self.hr_shape = hr_shape

    def __len__(self):
        """Return the number of annotation lines."""
        return self.train_batches

    def __getitem__(self, index):
        """Load, augment, resize and normalize sample ``index``.

        ``index`` is wrapped modulo the dataset length.

        Returns:
            ``(image_masked, image_origin)`` as channel-first numpy arrays.
        """
        index = index % self.train_batches
        # split() (rather than split(' ')) tolerates repeated spaces, tabs and the
        # trailing newline. Too few fields still raises IndexError.
        fields = self.train_lines[index].split()
        image_origin = self._open_image(fields[0])
        image_masked = self._open_image(fields[1])

        image_origin, image_masked = self.get_random_data(image_origin, image_masked, self.hr_shape)

        # The shapes are indexed [1], [0] because PIL's resize takes (width, height).
        # lr_shape/hr_shape therefore appear to be (height, width).
        image_origin = image_origin.resize((self.hr_shape[1], self.hr_shape[0]), Image.BICUBIC)
        image_masked = image_masked.resize((self.lr_shape[1], self.lr_shape[0]), Image.BICUBIC)

        image_origin = self._to_chw(image_origin)
        image_masked = self._to_chw(image_masked)

        return image_masked, image_origin

    @staticmethod
    def _open_image(path):
        """Open ``path`` with PIL and return an in-memory copy, closing the file immediately."""
        # Image.open is lazy and keeps the file open until the pixels are loaded.
        # copy() forces the load, and the `with` block then releases the handle
        # even if later processing raises.
        with Image.open(path) as image:
            return image.copy()

    @staticmethod
    def _to_chw(image):
        """Run ``preprocess_input`` on an HWC image and transpose the result to CHW."""
        # The two [0.5, 0.5, 0.5] arguments are presumably per-channel mean/std;
        # confirm against utils.preprocess_input.
        normalized = preprocess_input(np.array(image, dtype=np.float32), [0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
        return np.array(np.transpose(normalized, [2, 0, 1]))

    def rand(self, a=0, b=1):
        """Return a uniform float in ``[a, b)`` drawn from numpy's global RNG."""
        return np.random.rand() * (b - a) + a

    @staticmethod
    def _jitter_sv(image, sat, val):
        """Scale the HSV saturation and value of an RGB image by ``sat``/``val``.

        The channels are clamped to their valid ranges and a new uint8 PIL image
        is returned.
        """
        # With float32 input in [0, 1], OpenCV's HSV has H in [0, 360] and S, V in [0, 1].
        x = cv2.cvtColor(np.array(image, np.float32) / 255, cv2.COLOR_RGB2HSV)
        x[..., 1] *= sat
        x[..., 2] *= val
        x[x[:, :, 0] > 360, 0] = 360
        x[:, :, 1:][x[:, :, 1:] > 1] = 1
        x[x < 0] = 0
        return Image.fromarray(np.uint8(cv2.cvtColor(x, cv2.COLOR_HSV2RGB) * 255))

    def get_random_data(self, image_origin, image_masked, input_shape, jitter=.3, hue=.1, sat=1.5, val=1.5, random=True):
        """Convert both images with ``cvtColor`` and apply one shared photometric jitter.

        The same saturation/value factors are applied to both images, so the
        masked/original pair stays photometrically aligned.

        Args:
            image_origin: image of the unmasked target.
            image_masked: image of the masked input.
            input_shape: unused; kept for interface compatibility.
            jitter: unused; kept for interface compatibility.
            hue: half-width of the hue-offset draw (see the note in the body).
            sat: maximum saturation scale. The factor is drawn from ``[1, sat)``
                or is its reciprocal, with equal probability.
            val: maximum value/brightness scale, drawn the same way as ``sat``.
            random: when False, no jitter is applied and the converted pixels are
                returned unchanged.

        Returns:
            ``(image_origin, image_masked)`` as PIL images.
        """
        image_origin = cvtColor(image_origin)
        image_masked = cvtColor(image_masked)

        if not random:
            # Skip the HSV round-trip entirely, since it can perturb pixels through
            # float rounding even with unit factors.
            return (Image.fromarray(np.uint8(np.array(image_origin))),
                    Image.fromarray(np.uint8(np.array(image_masked))))

        # NOTE(review): a hue offset is drawn here, which keeps the RNG sequence
        # identical to the previous implementation. It is never applied to the
        # H channel. Confirm whether hue jitter is wanted.
        self.rand(-hue, hue)
        sat = self.rand(1, sat) if self.rand() < .5 else 1 / self.rand(1, sat)
        val = self.rand(1, val) if self.rand() < .5 else 1 / self.rand(1, val)

        return self._jitter_sv(image_origin, sat, val), self._jitter_sv(image_masked, sat, val)
|
|
|
|
|
def MASKGAN_dataset_collate(batch):
    """Collate ``(low_res, high_res)`` sample pairs into two stacked numpy arrays.

    Args:
        batch: iterable of 2-tuples, as returned by ``MASKGANDataset.__getitem__``.

    Returns:
        ``(np.array(all_low_res), np.array(all_high_res))``.
    """
    columns = ([], [])
    for sample in batch:
        low, high = sample
        columns[0].append(low)
        columns[1].append(high)
    return tuple(np.array(column) for column in columns)