import torch
import torch.utils.data as data
import torchvision.transforms as transforms

import os
import numpy as np
from PIL import Image
from glob import glob
from skimage.segmentation import slic
import torchvision.transforms.functional as TF
from scipy.io import loadmat

import sys
sys.path.append("../..")
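
# Dataset utilities for texture segmentation: each item yields a randomly
# cropped RGB image, optionally a SLIC superpixel one-hot map, the file
# name, the sample index, and (optionally) a ground-truth segmentation
# loaded from BSDS-style .mat annotation files.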

def label2one_hot_torch(labels, C=14):
    """Converts an integer label tensor to a one-hot encoding.

    Args:
      labels (tensor): segmentation label map
      C (int): number of classes in labels

    Returns:
      target (tensor): one-hot encoding of the input label map

    Shape:
      labels: (B, 1, H, W)
      target: (B, C, H, W)
    """
    b, _, h, w = labels.shape
    one_hot = torch.zeros(b, C, h, w, dtype=torch.long, device=labels.device)
    # scatter_ requires a long index tensor
    target = one_hot.scatter_(1, labels.long(), 1)

    return target.type(torch.float32)
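
# Minimal usage sketch of label2one_hot_torch (illustrative values):
#   lbl = torch.tensor([[[[0, 1], [2, 0]]]])   # (1, 1, 2, 2)
#   oh = label2one_hot_torch(lbl, C=3)         # (1, 3, 2, 2), float32
#   oh[0, 1]                                   # tensor([[0., 1.], [0., 0.]])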

class Dataset(data.Dataset):
    def __init__(self, data_dir, crop_size=128, test=False,
                 sp_num=256, slic=True, gt_label=False,
                 label_path=None, test_time=False, img_path=None):
        super(Dataset, self).__init__()
        self.test = test
        self.test_time = test_time

        # Recursively collect every jpg under data_dir.
        ext = ["*.jpg"]
        dl = []
        for e in ext:
            dl.extend(glob(os.path.join(data_dir, '**', e), recursive=True))
        self.data_list = sorted(dl)

        self.sp_num = sp_num
        # Note: this argument shadows skimage's slic() inside __init__ only;
        # the attribute just toggles superpixel computation in __getitem__.
        self.slic = slic

        self.crop = transforms.CenterCrop(size=(crop_size, crop_size))
        self.crop_size = crop_size

        self.gt_label = gt_label
        if gt_label:
            self.label_path = label_path

        self.img_path = img_path

    def preprocess_label(self, seg):
        # One-hot encode, drop channels for classes absent from the crop,
        # then re-index so labels are contiguous in [0, num_present_classes).
        n_cls = int(seg.max()) + 1
        segs = label2one_hot_torch(seg.unsqueeze(0), C=n_cls)
        new_seg = []
        for cnt in range(n_cls):
            if segs[0, cnt].sum() > 0:
                new_seg.append(segs[0, cnt])
        new_seg = torch.stack(new_seg)
        return torch.argmax(new_seg, dim=0)
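
    # Illustrative sketch: a cropped label map with values {0, 3, 7} comes
    # back relabeled as {0, 1, 2}, because the empty one-hot channels are
    # dropped before the final argmax re-indexes the classes.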

    def __getitem__(self, index):
        if self.img_path is None:
            data_path = self.data_list[index]
        else:
            data_path = self.img_path
        rgb_img = Image.open(data_path)
        imgW, imgH = rgb_img.size  # PIL's size is (width, height)

        if self.gt_label:
            # Load the BSDS-style .mat annotation and keep the human
            # segmentation with the largest number of labels.
            img_name = data_path.split("/")[-1].split("_")[0]
            mat_path = os.path.join(self.label_path, data_path.split('/')[-2], img_name.replace('.jpg', '.mat'))
            mat = loadmat(mat_path)
            max_label_num = 0
            final_seg = None
            for i in range(len(mat['groundTruth'][0])):
                seg = mat['groundTruth'][0][i][0][0][0]
                if len(np.unique(seg)) > max_label_num:
                    max_label_num = len(np.unique(seg))
                    final_seg = seg
            seg = torch.from_numpy(final_seg.astype(np.float32))
            segs = seg.long().unsqueeze(0)

        if self.img_path is None:
            # Random crop for dataset iteration.
            i, j, h, w = transforms.RandomCrop.get_params(rgb_img, output_size=(self.crop_size, self.crop_size))
        else:
            # Fixed crop when a single image path is given.
            i, j = 40, 40
            h = w = self.crop_size
        rgb_img = TF.crop(rgb_img, i, j, h, w)
        if self.gt_label:
            segs = TF.crop(segs, i, j, h, w)
            segs = self.preprocess_label(segs)

        if self.slic:
            sp_num = self.sp_num
            # Compute a SLIC superpixel over-segmentation of the crop, clamp
            # any label overflow, and convert it to a (sp_num, H, W) one-hot map.
            slic_i = slic(np.array(rgb_img), n_segments=sp_num, compactness=10, start_label=0, min_size_factor=0.3)
            slic_i = torch.from_numpy(slic_i)
            slic_i[slic_i >= sp_num] = sp_num - 1
            oh = label2one_hot_torch(slic_i.unsqueeze(0).unsqueeze(0), C=sp_num).squeeze()

        rgb_img = TF.to_tensor(rgb_img)
        if rgb_img.shape[0] == 1:
            # Grayscale image: replicate to three channels.
            rgb_img = rgb_img.repeat(3, 1, 1)
        rgb_img = rgb_img[:3, :, :]  # drop the alpha channel if present

        # Return order: image, [superpixel one-hot], file name, index, [gt label].
        rets = [rgb_img]
        if self.slic:
            rets.append(oh)
        rets.append(data_path.split("/")[-1])
        rets.append(index)
        if self.gt_label:
            rets.append(segs.view(1, segs.shape[-2], segs.shape[-1]))
        return rets

    def __len__(self):
        return len(self.data_list)
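
# Hedged usage sketch for the ground-truth mode (the BSDS500 paths below are
# assumptions, not taken from this repo):
#   ds = Dataset('/path/to/BSDS500/data/images', sp_num=256, gt_label=True,
#                label_path='/path/to/BSDS500/data/groundTruth')
#   img, oh, name, idx, seg = ds[0]   # seg: (1, crop_size, crop_size) labels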

if __name__ == '__main__':
    import torchvision.utils as vutils
    dataset = Dataset('/home/xtli/DATA/texture_data/',
                      sp_num=256)
    loader_ = torch.utils.data.DataLoader(dataset=dataset,
                                          batch_size=1,
                                          shuffle=True,
                                          num_workers=1,
                                          drop_last=True)
    # __getitem__ returns [image, superpixel one-hot, file name, index].
    img, oh, name, index = next(iter(loader_))

    # Paint each superpixel with its mean color as a quick sanity check.
    crop_size = 128
    sp = torch.argmax(oh, dim=1)  # (1, H, W) integer superpixel labels
    canvas = torch.zeros((1, 3, crop_size, crop_size))
    for s in sp[0].unique():
        mask = sp[0] == s
        canvas[0, :, mask] = img[0, :, mask].mean(dim=1, keepdim=True)
    vutils.save_image(canvas, 'canvas.png')
    vutils.save_image(img, 'img.png')