File size: 3,697 Bytes
73ca179
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
from random import randint

import cv2
import numpy as np
from PIL import Image
from torch.utils.data.dataset import Dataset

from .utils import cvtColor, preprocess_input

def look_image(image_name, image):
    """Debug helper: display ``image`` in an OpenCV window and block until a key is pressed.

    Args:
        image_name: title of the OpenCV window.
        image: anything ``np.array`` accepts (e.g. a PIL image); its first and
            third channels are swapped before display.
    """
    # OpenCV windows expect BGR channel order, so swap R and B before showing.
    frame = cv2.cvtColor(np.array(image), cv2.COLOR_BGR2RGB)
    cv2.imshow(image_name, frame)
    # 0 => wait indefinitely for a key press.
    cv2.waitKey(0)


def get_new_img_size(width, height, img_min_side=600):
    """Compute an aspect-preserving size whose shorter side equals ``img_min_side``.

    Args:
        width: current image width.
        height: current image height.
        img_min_side: target length of the shorter side (default 600).

    Returns:
        ``(resized_width, resized_height)`` as ints; the longer side is
        truncated toward zero after scaling. Ties (width == height) scale by width.
    """
    if height < width:
        # Height is the short side: pin it and scale width proportionally.
        scale = float(img_min_side) / height
        return int(scale * width), int(img_min_side)

    # Width is the short side (or the sides are equal).
    scale = float(img_min_side) / width
    return int(img_min_side), int(scale * height)

class MASKGANDataset(Dataset):
    """Dataset of (masked image, original image) pairs for MASKGAN training.

    Each entry of ``train_lines`` is a space-separated annotation line whose
    first field is the path of the original image and whose second field is the
    path of the masked image. ``__getitem__`` returns ``(masked, original)``:
    the masked image resized to ``lr_shape`` and the original resized to
    ``hr_shape``, both normalized with ``preprocess_input`` and transposed with
    axes ``[2, 0, 1]`` (channel-first).
    """

    def __init__(self, train_lines, lr_shape, hr_shape):
        super(MASKGANDataset, self).__init__()

        # One sample per annotation line.
        self.train_lines    = train_lines
        self.train_batches  = len(train_lines)

        # Target sizes indexed as (height, width): PIL's resize takes
        # (width, height), so __getitem__ passes (shape[1], shape[0]).
        self.lr_shape       = lr_shape
        self.hr_shape       = hr_shape

    def __len__(self):
        return self.train_batches

    @staticmethod
    def _load_image(path):
        """Open ``path`` with PIL and return an in-memory copy, releasing the file handle."""
        # Image.open is lazy and keeps the file open; copy() forces a load into
        # a file-independent image so the context manager can close the handle.
        with Image.open(path) as image:
            return image.copy()

    def __getitem__(self, index):
        """Return ``(masked, original)`` arrays for annotation line ``index``.

        ``index`` wraps modulo the dataset length.
        """
        index = index % self.train_batches
        image_list = self.train_lines[index].split(' ')
        image_origin = self._load_image(image_list[0])
        # The second field may still carry the line's trailing newline; split() strips it.
        image_masked = self._load_image(image_list[1].split()[0])

        image_origin, image_masked = self.get_random_data(image_origin, image_masked, self.hr_shape)

        image_origin = image_origin.resize((self.hr_shape[1], self.hr_shape[0]), Image.BICUBIC)
        image_masked = image_masked.resize((self.lr_shape[1], self.lr_shape[0]), Image.BICUBIC)
        # look_image('origin', image_origin)
        # look_image('masked', image_masked)
        # Normalize with mean/std 0.5 per channel (presumably mapping to [-1, 1];
        # confirm in utils.preprocess_input), then HWC -> CHW.
        image_origin = np.transpose(preprocess_input(np.array(image_origin, dtype=np.float32), [0.5,0.5,0.5], [0.5,0.5,0.5]), [2,0,1])
        image_masked = np.transpose(preprocess_input(np.array(image_masked, dtype=np.float32), [0.5,0.5,0.5], [0.5,0.5,0.5]), [2,0,1])

        return np.array(image_masked), np.array(image_origin)

    def rand(self, a=0, b=1):
        """Return a uniform float in ``[a, b)`` drawn from numpy's global RNG."""
        return np.random.rand()*(b-a) + a

    @staticmethod
    def _jitter_sv(image, sat, val):
        """Scale saturation by ``sat`` and value by ``val`` in HSV space; return a uint8 PIL image."""
        x = cv2.cvtColor(np.array(image, np.float32) / 255, cv2.COLOR_RGB2HSV)
        x[..., 1] *= sat
        x[..., 2] *= val
        # Clamp back into the float HSV ranges: H <= 360, S and V <= 1, nothing negative.
        x[x[:, :, 0] > 360, 0] = 360
        x[:, :, 1:][x[:, :, 1:] > 1] = 1
        x[x < 0] = 0
        return Image.fromarray(np.uint8(cv2.cvtColor(x, cv2.COLOR_HSV2RGB) * 255))

    def get_random_data(self, image_origin, image_masked, input_shape, jitter=.3, hue=.1, sat=1.5, val=1.5, random=True):
        """Convert both images to RGB and apply the same random colour jitter to each.

        Args:
            image_origin: original (target) image.
            image_masked: masked (input) image.
            input_shape: not used here; kept for interface compatibility.
            jitter: not used here; kept for interface compatibility.
            hue: bound for the hue draw. NOTE(review): the hue offset is drawn
                but never applied to the pixels; confirm whether hue jitter
                was intentionally disabled.
            sat: saturation gain bound; the gain is drawn from [1, sat) or
                its reciprocal, each with probability 1/2.
            val: value (brightness) gain bound, same scheme as ``sat``.
            random: if False, skip the jitter and return the RGB-converted images.

        Returns:
            ``(image_origin, image_masked)`` as PIL images.
        """
        #------------------------------#
        #   Read images and convert to RGB
        #------------------------------#
        image_origin   = cvtColor(image_origin)
        image_masked   = cvtColor(image_masked)

        if not random:
            return image_origin, image_masked

        #------------------------------------------#
        #   Colour-space distortion (shared by both images)
        #------------------------------------------#
        # The hue draw is kept so the RNG stream (and the sat/val draws after it)
        # stays identical to earlier runs, even though the value is unused.
        hue = self.rand(-hue, hue)
        sat = self.rand(1, sat) if self.rand()<.5 else 1/self.rand(1, sat)
        val = self.rand(1, val) if self.rand()<.5 else 1/self.rand(1, val)

        return self._jitter_sv(image_origin, sat, val), self._jitter_sv(image_masked, sat, val)

        
def MASKGAN_dataset_collate(batch):
    """Collate ``(low_res, high_res)`` sample pairs into two stacked numpy arrays.

    Args:
        batch: iterable of 2-tuples ``(low_res, high_res)``, e.g. as produced
            by ``MASKGANDataset.__getitem__``.

    Returns:
        ``(np.array(all low_res), np.array(all high_res))`` in batch order.
    """
    # Materialize the pairs once so a one-shot iterable is consumed a single time.
    pairs = [(low_res, high_res) for low_res, high_res in batch]
    low_batch = np.array([pair[0] for pair in pairs])
    high_batch = np.array([pair[1] for pair in pairs])
    return low_batch, high_batch