| | """
|
| | CD Dataset
|
| | """
|
| | import os
|
| | from PIL import Image
|
| | import numpy as np
|
| | from torch.utils import data
|
| | import data.util as Util
|
| | from torch.utils.data import Dataset
|
| | import torchvision
|
| | import torch
|
| |
|
| | totensor = torchvision.transforms.ToTensor()
|
| |
|
| | """
|
| | CD Dataset
|
| | ├─image
|
| | ├─image_post
|
| | ├─label
|
| | └─list
|
| | """
|
| |
|
| | IMG_FOLDER_NAME = 'A'
|
| | IMG_POST_FOLDER_NAME = 'B'
|
| | LABEL_FOLDER_NAME = 'label'
|
| | LABEL1_FOLDER_NAME = 'label1'
|
| | LABEL2_FOLDER_NAME = 'label2'
|
| | LIST_FOLDER_NAME = 'list'
|
| |
|
| | label_suffix = ".png"
|
| |
|
| |
|
| | def load_img_name_list(dataset_path):
|
| | img_name_list = np.loadtxt(dataset_path, dtype=np.str_)
|
| | if img_name_list.ndim == 2:
|
| | return img_name_list[:, 0]
|
| | return img_name_list
|
| |
|
| |
|
| | def get_img_path(root_dir, img_name):
|
| | return os.path.join(root_dir, IMG_FOLDER_NAME, img_name)
|
| |
|
| | def get_img_post_path(root_dir, img_name):
|
| | return os.path.join(root_dir, IMG_POST_FOLDER_NAME, img_name)
|
| |
|
| | def get_label_path(root_dir, img_name):
|
| | return os.path.join(root_dir, LABEL_FOLDER_NAME, img_name)
|
| |
|
| | def get_label1_path(root_dir, img_name):
|
| | return os.path.join(root_dir, LABEL1_FOLDER_NAME, img_name)
|
| |
|
| | def get_label2_path(root_dir, img_name):
|
| | return os.path.join(root_dir, LABEL2_FOLDER_NAME, img_name)
|
| |
|
| | class CDDataset(Dataset):
|
| | def __init__(self, root_dir, resolution=256, split='train', data_len=-1, label_transform=None):
|
| |
|
| | self.root_dir = root_dir
|
| | self.resolution = resolution
|
| | self.data_len = data_len
|
| | self.split = split
|
| | self.label_transform = label_transform
|
| |
|
| | self.list_path = os.path.join(self.root_dir, LIST_FOLDER_NAME, self.split + '.txt')
|
| |
|
| | self.img_name_list = load_img_name_list(self.list_path)
|
| |
|
| | self.dataset_len = len(self.img_name_list)
|
| |
|
| | if self.data_len <= 0:
|
| | self.data_len = self.dataset_len
|
| | else:
|
| | self.data_len = min(self.dataset_len, self.data_len)
|
| |
|
| | def __len__(self):
|
| | return self.data_len
|
| |
|
| | def __getitem__(self, index):
|
| | A_path = get_img_path(self.root_dir, self.img_name_list[index % self.data_len])
|
| | B_path = get_img_post_path(self.root_dir, self.img_name_list[index % self.data_len])
|
| |
|
| | img_A = Image.open(A_path).convert('RGB')
|
| | img_B = Image.open(B_path).convert('RGB')
|
| |
|
| | L_path = get_label_path(self.root_dir, self.img_name_list[index % self.data_len])
|
| | img_label = Image.open(L_path).convert("RGB")
|
| |
|
| | img_A = Util.transform_augment_cd(img_A, min_max=(-1, 1))
|
| | img_B = Util.transform_augment_cd(img_B, min_max=(-1, 1))
|
| | img_label = Util.transform_augment_cd(img_label, min_max=(0, 1))
|
| | if img_label.dim() > 2:
|
| | img_label = img_label[0]
|
| |
|
| | return {'A':img_A, 'B':img_B, 'L':img_label, 'Index':index}
|
| |
|
| |
|
| |
|
| | class SCDDataset(Dataset):
|
| | def __init__(self, root_dir, resolution=512, split='train', data_len=-1, label_transform=None):
|
| |
|
| | self.root_dir = root_dir
|
| | self.resolution = resolution
|
| | self.data_len = data_len
|
| | self.split = split
|
| | self.label_transform = label_transform
|
| |
|
| | self.list_path = os.path.join(self.root_dir, LIST_FOLDER_NAME, self.split + '.txt')
|
| |
|
| | self.img_name_list = load_img_name_list(self.list_path)
|
| |
|
| | self.dataset_len = len(self.img_name_list)
|
| |
|
| | if self.data_len <= 0:
|
| | self.data_len = self.dataset_len
|
| | else:
|
| | self.data_len = min(self.dataset_len, self.data_len)
|
| |
|
| | def __len__(self):
|
| | return self.data_len
|
| |
|
| | def __getitem__(self, index):
|
| | A_path = get_img_path(self.root_dir, self.img_name_list[index % self.data_len])
|
| | B_path = get_img_post_path(self.root_dir, self.img_name_list[index % self.data_len])
|
| | name = A_path.split('\\')[-1].split('.')[0]
|
| | img_A = Image.open(A_path).convert('RGB')
|
| | img_B = Image.open(B_path).convert('RGB')
|
| |
|
| | L_path = get_label_path(self.root_dir, self.img_name_list[index % self.data_len])
|
| | L1_path = get_label1_path(self.root_dir, self.img_name_list[index % self.data_len])
|
| | L2_path = get_label2_path(self.root_dir, self.img_name_list[index % self.data_len])
|
| | img_label = np.array(Image.open(L_path), dtype=np.uint8)
|
| | img_label1 = np.array(Image.open(L1_path), dtype=np.uint8)
|
| | img_label2 = np.array(Image.open(L2_path), dtype=np.uint8)
|
| |
|
| | img_A = Util.transform_augment_cd(img_A, min_max=(-1, 1))
|
| | img_B = Util.transform_augment_cd(img_B, min_max=(-1, 1))
|
| | img_label = torch.from_numpy(img_label)
|
| | img_label1 = torch.from_numpy(img_label1)
|
| |
|
| | cls_category1 = torch.unique(img_label1)
|
| | cls_label1 = torch.zeros(7, dtype = int)
|
| | for index in cls_category1:
|
| | cls_label1[int(index)] = 1
|
| |
|
| | img_label2 = torch.from_numpy(img_label2)
|
| |
|
| | cls_category2 = torch.unique(img_label2)
|
| | cls_label2 = torch.zeros(7, dtype=int)
|
| | for index in cls_category2:
|
| | cls_label2[int(index)] = 1
|
| |
|
| | if img_label.dim() > 2:
|
| | img_label = img_label[0]
|
| | img_label1 = img_label1[0]
|
| | img_label2 = img_label2[0]
|
| |
|
| | return {'A':img_A, 'B':img_B, 'L':img_label, 'L1':img_label1, 'L2':img_label2,
|
| | 'Index':index, 'name':name, 'cls1':cls_label1, 'cls2':cls_label2}
|
| |
|
| | if __name__ == '__main__':
|
| | root_dir = r'E:\cddataset\mmcd\Second_my'
|
| | cddata = SCDDataset(root_dir=root_dir)
|
| | list_path = os.path.join(root_dir, 'list', 'val', '.txt')
|
| | for i in range(593):
|
| | cls_labe1 = cddata.__getitem__(i)['cls1']
|
| | print(cls_labe1)
|
| | cls_labe2 = cddata.__getitem__(i)['cls2']
|
| | print(cls_labe2)
|
| |
|