File size: 4,306 Bytes
533763f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
import os.path

from torch.utils.data import Dataset, DataLoader
import torch
import numpy as np
import pandas as pd
from skimage import io
from Utils.Augmentations import Augmentations, Resize


class Datasets(Dataset):
    """Segmentation dataset driven by a CSV index file.

    The CSV (read with its first column as the index) must contain an
    ``img`` column and a ``label`` column holding file paths.
    """

    def __init__(self, data_file, transform=None, phase='train', *args, **kwargs):
        """
        :param data_file: path to a CSV listing 'img' and 'label' file paths
        :param transform: optional callable applied to (image, label) pairs
        :param phase: dataset phase tag, e.g. 'train' / 'val' / 'test'
        """
        # Extra *args/**kwargs are accepted (and ignored) so callers may
        # pass additional positional config objects.
        self.transform = transform
        self.data_info = pd.read_csv(data_file, index_col=0)
        self.phase = phase

    def __len__(self):
        return len(self.data_info)

    def __getitem__(self, index):
        data = self.pull_item_seg(index)
        return data

    def pull_item_seg(self, index):
        """Load one (image, label) pair and build a one-hot binary mask.

        :param index: row index into the CSV listing
        :return: dict with 'img' (C,H,W float tensor), 'label'
                 (2,H,W one-hot tensor: channel 0 = background, 1 = foreground),
                 and 'img_name' (basename of the image path)
        """
        data = self.data_info.iloc[index]
        img_name = data['img']
        label_name = data['label']

        ori_img = io.imread(img_name, as_gray=False)
        ori_label = io.imread(label_name, as_gray=True)
        assert (ori_img is not None and ori_label is not None), f'{img_name} or {label_name} is not valid'

        # Fall back to the raw arrays when no transform is configured;
        # the original code left img/label unbound in that case (NameError).
        if self.transform is not None:
            img, label = self.transform((ori_img, ori_label))
        else:
            img, label = ori_img, ori_label

        # np.float was removed in NumPy 1.24 — use the builtin float dtype.
        one_hot_label = np.zeros([2] + list(label.shape), dtype=float)
        one_hot_label[0] = label == 0
        one_hot_label[1] = label > 0
        return_dict = {
            'img': torch.from_numpy(img).permute(2, 0, 1),
            'label': torch.from_numpy(one_hot_label),
            'img_name': os.path.basename(img_name)
        }
        return return_dict


def get_data_loader(config, test_mode=False):
    """Build the DataLoader(s) requested by *config*.

    :param config: dict-like configuration; keys used include DATASET,
        VAL_DATASET, TEST_DATASET, BATCH_SIZE, VAL_BATCH_SIZE, IS_SHUFFLE,
        IS_VAL, IS_TEST, NUM_WORKERS, IMG_SIZE, PRIOR_MEAN, PRIOR_STD, PHASE
    :param test_mode: when True, skip building the training loader
    :return: dict mapping 'train' / 'val' / 'test' to DataLoader instances
    """
    def _make_set(dataset_key, aug_phase):
        # One Datasets instance with phase-specific augmentations.
        return Datasets(
            config[dataset_key],
            Augmentations(
                config['IMG_SIZE'], config['PRIOR_MEAN'], config['PRIOR_STD'], aug_phase, config['PHASE'], config
            ),
            config['PHASE'],
            config
        )

    def _loader_params(batch_size, shuffle):
        # Shared DataLoader keyword arguments.
        return {
            'batch_size': batch_size,
            'shuffle': shuffle,
            'drop_last': False,
            'collate_fn': collate_fn,
            'num_workers': config['NUM_WORKERS'],
            'pin_memory': False
        }

    # Explicit dict construction instead of the original eval(x+'_set') /
    # eval(x+'_params') indirection — same loaders, same insertion order.
    data_loaders = {}

    if not test_mode:
        data_loaders['train'] = DataLoader(
            _make_set('DATASET', 'train'),
            **_loader_params(config['BATCH_SIZE'], config['IS_SHUFFLE'])
        )

    if config['IS_VAL']:
        data_loaders['val'] = DataLoader(
            _make_set('VAL_DATASET', 'val'),
            **_loader_params(config['VAL_BATCH_SIZE'], False)
        )

    if config['IS_TEST']:
        data_loaders['test'] = DataLoader(
            _make_set('TEST_DATASET', 'test'),
            **_loader_params(config['VAL_BATCH_SIZE'], False)
        )

    return data_loaders


def collate_fn(batch):
    """Collate a list of sample dicts into one batch dict.

    Numpy-array values are converted to float tensors; every key except
    'img_name' is stacked along a new batch dimension, while 'img_name'
    stays a plain list of strings.

    :param batch: list of per-sample dicts sharing the same keys
    :return: dict of batched values
    """
    def to_tensor(item):
        # Convert numpy arrays to float tensors; pass everything else
        # (tensors, strings, lists, dicts, scalars) through unchanged.
        # The original returned None for unhandled types, silently
        # dropping the value.
        if isinstance(item, np.ndarray):
            return torch.from_numpy(item).float()
        return item

    return_data = {key: [] for key in batch[0].keys()}

    for sample in batch:
        for key, value in sample.items():
            return_data[key].append(to_tensor(value))

    # Stack every field except the image names, which remain a list.
    keys = set(batch[0].keys()) - {'img_name'}
    for key in keys:
        return_data[key] = torch.stack(return_data[key], dim=0)

    return return_data