InPeerReview committed on
Commit 8cf4db8 · verified · 1 Parent(s): e15bd17

Upload 5 files

Files changed (5)
  1. data/CDDataset.py +168 -0
  2. data/__init__.py +68 -0
  3. data/colormap.py +3 -0
  4. data/generate_list.py +12 -0
  5. data/util.py +8 -0
data/CDDataset.py ADDED
@@ -0,0 +1,168 @@
+ """
+ CD Dataset
+ """
+ import os
+ from PIL import Image
+ import numpy as np
+ from torch.utils import data
+ import data.util as Util
+ from torch.utils.data import Dataset
+ import torchvision
+ import torch
+
+ totensor = torchvision.transforms.ToTensor()
+
+ """
+ Expected dataset layout under root_dir:
+ ├─A          pre-change images
+ ├─B          post-change images
+ ├─label      change labels (label1/label2 hold the per-date semantic labels for SCD)
+ └─list       train/val/test split files (<split>.txt, one image name per line)
+ """
+
+ IMG_FOLDER_NAME = 'A'
+ IMG_POST_FOLDER_NAME = 'B'
+ LABEL_FOLDER_NAME = 'label'
+ LABEL1_FOLDER_NAME = 'label1'
+ LABEL2_FOLDER_NAME = 'label2'
+ LIST_FOLDER_NAME = 'list'
+
+ label_suffix = ".png"
+
+ # The list files store one image name per line; read them into an array of names
+ def load_img_name_list(dataset_path):
+     img_name_list = np.loadtxt(dataset_path, dtype=np.str_)
+     if img_name_list.ndim == 2:
+         return img_name_list[:, 0]
+     return img_name_list
+
+ # Helpers that build the path into each sub-folder for a given image name
+ def get_img_path(root_dir, img_name):
+     return os.path.join(root_dir, IMG_FOLDER_NAME, img_name)
+
+ def get_img_post_path(root_dir, img_name):
+     return os.path.join(root_dir, IMG_POST_FOLDER_NAME, img_name)
+
+ def get_label_path(root_dir, img_name):
+     return os.path.join(root_dir, LABEL_FOLDER_NAME, img_name)
+
+ def get_label1_path(root_dir, img_name):
+     return os.path.join(root_dir, LABEL1_FOLDER_NAME, img_name)
+
+ def get_label2_path(root_dir, img_name):
+     return os.path.join(root_dir, LABEL2_FOLDER_NAME, img_name)
+
+ class CDDataset(Dataset):
+     def __init__(self, root_dir, resolution=256, split='train', data_len=-1, label_transform=None):
+
+         self.root_dir = root_dir
+         self.resolution = resolution
+         self.data_len = data_len
+         self.split = split  # train / val / test
+         self.label_transform = label_transform
+
+         self.list_path = os.path.join(self.root_dir, LIST_FOLDER_NAME, self.split + '.txt')
+
+         self.img_name_list = load_img_name_list(self.list_path)
+
+         self.dataset_len = len(self.img_name_list)
+
+         if self.data_len <= 0:
+             self.data_len = self.dataset_len
+         else:
+             self.data_len = min(self.dataset_len, self.data_len)
+
+     def __len__(self):
+         return self.data_len
+
+     def __getitem__(self, index):
+         A_path = get_img_path(self.root_dir, self.img_name_list[index % self.data_len])
+         B_path = get_img_post_path(self.root_dir, self.img_name_list[index % self.data_len])
+
+         img_A = Image.open(A_path).convert('RGB')
+         img_B = Image.open(B_path).convert('RGB')
+
+         L_path = get_label_path(self.root_dir, self.img_name_list[index % self.data_len])
+         img_label = Image.open(L_path).convert("RGB")
+
+         img_A = Util.transform_augment_cd(img_A, min_max=(-1, 1))
+         img_B = Util.transform_augment_cd(img_B, min_max=(-1, 1))
+         img_label = Util.transform_augment_cd(img_label, min_max=(0, 1))
+         if img_label.dim() > 2:
+             img_label = img_label[0]
+
+         return {'A': img_A, 'B': img_B, 'L': img_label, 'Index': index}
+
+
+
+ class SCDDataset(Dataset):
+     def __init__(self, root_dir, resolution=512, split='train', data_len=-1, label_transform=None):
+
+         self.root_dir = root_dir
+         self.resolution = resolution
+         self.data_len = data_len
+         self.split = split  # train / val / test
+         self.label_transform = label_transform
+
+         self.list_path = os.path.join(self.root_dir, LIST_FOLDER_NAME, self.split + '.txt')
+
+         self.img_name_list = load_img_name_list(self.list_path)
+
+         self.dataset_len = len(self.img_name_list)
+
+         if self.data_len <= 0:
+             self.data_len = self.dataset_len
+         else:
+             self.data_len = min(self.dataset_len, self.data_len)
+
+     def __len__(self):
+         return self.data_len
+
+     def __getitem__(self, index):
+         A_path = get_img_path(self.root_dir, self.img_name_list[index % self.data_len])
+         B_path = get_img_post_path(self.root_dir, self.img_name_list[index % self.data_len])
+         name = os.path.splitext(os.path.basename(A_path))[0]  # image name without extension
+         img_A = Image.open(A_path).convert('RGB')
+         img_B = Image.open(B_path).convert('RGB')
+
+         L_path = get_label_path(self.root_dir, self.img_name_list[index % self.data_len])
+         L1_path = get_label1_path(self.root_dir, self.img_name_list[index % self.data_len])
+         L2_path = get_label2_path(self.root_dir, self.img_name_list[index % self.data_len])
+         img_label = np.array(Image.open(L_path), dtype=np.uint8)
+         img_label1 = np.array(Image.open(L1_path), dtype=np.uint8)
+         img_label2 = np.array(Image.open(L2_path), dtype=np.uint8)
+
+         img_A = Util.transform_augment_cd(img_A, min_max=(-1, 1))
+         img_B = Util.transform_augment_cd(img_B, min_max=(-1, 1))
+         img_label = torch.from_numpy(img_label)
+         img_label1 = torch.from_numpy(img_label1)
+         # multi-hot classification vector of the classes present in label1
+         cls_category1 = torch.unique(img_label1)
+         cls_label1 = torch.zeros(7, dtype=torch.long)
+         for c in cls_category1:
+             cls_label1[int(c)] = 1
+
+         img_label2 = torch.from_numpy(img_label2)
+         # multi-hot classification vector of the classes present in label2
+         cls_category2 = torch.unique(img_label2)
+         cls_label2 = torch.zeros(7, dtype=torch.long)
+         for c in cls_category2:
+             cls_label2[int(c)] = 1
+
+         if img_label.dim() > 2:
+             img_label = img_label[0]
+             img_label1 = img_label1[0]
+             img_label2 = img_label2[0]
+
+         return {'A': img_A, 'B': img_B, 'L': img_label, 'L1': img_label1, 'L2': img_label2,
+                 'Index': index, 'name': name, 'cls1': cls_label1, 'cls2': cls_label2}
+
+ if __name__ == '__main__':
+     root_dir = r'E:\cddataset\mmcd\Second_my'
+     cddata = SCDDataset(root_dir=root_dir)
+     for i in range(len(cddata)):
+         cls_label1 = cddata[i]['cls1']
+         print(cls_label1)
+         cls_label2 = cddata[i]['cls2']
+         print(cls_label2)
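
A minimal usage sketch for CDDataset (the root path below is only a placeholder; the directory must follow the A / B / label / list layout described in the module docstring, including a list/train.txt split file):

from torch.utils.data import DataLoader
from data.CDDataset import CDDataset

# hypothetical dataset root containing A/, B/, label/ and list/train.txt
dataset = CDDataset(root_dir='/path/to/cd_dataset', resolution=256, split='train')
loader = DataLoader(dataset, batch_size=4, shuffle=True, num_workers=2)

batch = next(iter(loader))
# 'A' and 'B' are normalised to [-1, 1]; 'L' is the change label scaled to [0, 1]
print(batch['A'].shape, batch['B'].shape, batch['L'].shape)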
data/__init__.py ADDED
@@ -0,0 +1,68 @@
+ import argparse
+ import logging
+ import core.logger as Logger
+ import data as Data
+
+ # Create change detection datasets and dataloaders
+ import torch.utils.data
+
+ def create_cd_dataloader(dataset, dataset_opt, phase):
+     if phase in ('train', 'val', 'test'):
+         return torch.utils.data.DataLoader(
+             dataset,
+             batch_size=dataset_opt['batch_size'],
+             shuffle=dataset_opt['use_shuffle'],
+             num_workers=dataset_opt['num_workers'],
+             pin_memory=True)
+     else:
+         raise NotImplementedError(
+             'Dataloader [{:s}] is not found'.format(phase)
+         )
+
23
+ def create_cd_dataset(dataset_opt, phase):
24
+ from data.CDDataset import CDDataset
25
+ print(dataset_opt["datasetroot"])
26
+ dataset = CDDataset(root_dir=dataset_opt["datasetroot"],
27
+ resolution=dataset_opt["resolution"],
28
+ split=phase,
29
+ data_len=dataset_opt["data_len"]
30
+ )
31
+ logger = logging.getLogger('base')
32
+ logger.info('Dataset [{:s} - {:s} - {:s}] is created'.format(dataset.__class__.__name__,
33
+ dataset_opt['name'],
34
+ phase))
35
+ return dataset
36
+
37
+ def create_scd_dataset(dataset_opt, phase):
38
+ from data.CDDataset import SCDDataset
39
+ print(dataset_opt["datasetroot"])
40
+ dataset = SCDDataset(root_dir=dataset_opt["datasetroot"],
41
+ resolution=dataset_opt["resolution"],
42
+ split=phase,
43
+ data_len=dataset_opt["data_len"]
44
+ )
45
+ logger = logging.getLogger('base')
46
+ logger.info('Dataset [{:s} - {:s} - {:s}] is created'.format(dataset.__class__.__name__,
47
+ dataset_opt['name'],
48
+ phase))
49
+ return dataset
50
+
51
+
52
+ if __name__ == "__main__":
53
+ parser = argparse.ArgumentParser()
54
+ parser.add_argument('-c', '--config', type=str, default='../config/levir.json')
55
+ parser.add_argument('-p', '--phase', type=str, choices=['train', 'test'], default='train')
56
+ parser.add_argument('-gpu', '--gpu_ids', type=str, default=None)
57
+
58
+ args = parser.parse_args()
59
+ opt = Logger.parse(args)
60
+ opt = Logger.dict_to_nonedict(opt)
61
+ print(opt)
62
+
63
+ for phase, dataset_opt in opt['datasets'].items():
64
+ if phase == 'train' and args.phase != 'test':
65
+ print("Creating [train] change-detection dataloader.")
66
+ train_set = Data.create_cd_dataset(dataset_opt, phase)
67
+ train_loader = Data.create_cd_dataloader(train_set, dataset_opt, phase)
68
+
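
For reference, a sketch of the dataset_opt entries that create_cd_dataset and create_cd_dataloader read; the key names come from the code above, while the concrete values are placeholders:

# hypothetical dataset options block (e.g. one entry of opt['datasets'] in config/levir.json)
dataset_opt = {
    'name': 'LEVIR-CD',                    # used only in the log message
    'datasetroot': '/path/to/cd_dataset',  # placeholder path to the dataset root
    'resolution': 256,
    'data_len': -1,                        # -1 -> use the whole split
    'batch_size': 8,
    'use_shuffle': True,
    'num_workers': 4,
}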
data/colormap.py ADDED
@@ -0,0 +1,3 @@
+ # RGB palette for the 7 label indices used by SCDDataset (one colour per semantic class)
+ second_colormap = [[255, 255, 255], [0, 0, 255], [128, 128, 128], [0, 128, 0], [0, 255, 0], [128, 0, 0], [255, 0, 0]]
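
A small sketch showing how this palette can turn an index label map into an RGB image; the colorize helper below is hypothetical and not part of this upload:

import numpy as np
from data.colormap import second_colormap

def colorize(label_map):
    # label_map: H x W integer array of class indices in [0, 6]
    palette = np.array(second_colormap, dtype=np.uint8)
    return palette[label_map]  # H x W x 3 uint8 RGB image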
data/generate_list.py ADDED
@@ -0,0 +1,12 @@
+ import os
+
+ def generate_list(root, split):
+     # Write one file name per line to <root>/<split>.txt
+     list_path = os.path.join(root, split + '.txt')
+     img_names = sorted(n for n in os.listdir(root) if not n.endswith('.txt'))
+     with open(list_path, 'w') as f:
+         for img_name in img_names:
+             f.write(img_name + '\n')
+
+ if __name__ == "__main__":
+     root = r'E:\cddataset\mmcd\Second_my\val\im1'
+     generate_list(root, 'val')
data/util.py ADDED
@@ -0,0 +1,8 @@
+ import torchvision
+
+ totensor = torchvision.transforms.ToTensor()
+
+ def transform_augment_cd(img, min_max=(0, 1)):
+     # ToTensor maps a PIL image to a float tensor in [0, 1];
+     # rescale it linearly into the [min_max[0], min_max[1]] range
+     img = totensor(img)
+     ret_img = img * (min_max[1] - min_max[0]) + min_max[0]
+     return ret_img
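
For example, with min_max=(-1, 1) the [0, 1] output of ToTensor is mapped linearly onto [-1, 1]; a quick sanity check (the test image below is an assumption, not project data):

from PIL import Image
import numpy as np
from data.util import transform_augment_cd

img = Image.fromarray(np.zeros((4, 4, 3), dtype=np.uint8))   # all-black test image
t = transform_augment_cd(img, min_max=(-1, 1))
print(t.min().item(), t.max().item())   # both -1.0, since every pixel maps 0 -> -1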