import os import random from os.path import join import numpy as np import torch.multiprocessing from PIL import Image from scipy.io import loadmat from torch.utils.data import DataLoader from torch.utils.data import Dataset from torchvision.datasets.cityscapes import Cityscapes from torchvision.transforms.functional import to_pil_image from tqdm import tqdm def bit_get(val, idx): """Gets the bit value. Args: val: Input value, int or numpy int array. idx: Which bit of the input val. Returns: The "idx"-th bit of input val. """ return (val >> idx) & 1 def create_pascal_label_colormap(): """Creates a label colormap used in PASCAL VOC segmentation benchmark. Returns: A colormap for visualizing segmentation results. """ colormap = np.zeros((512, 3), dtype=int) ind = np.arange(512, dtype=int) for shift in reversed(list(range(8))): for channel in range(3): colormap[:, channel] |= bit_get(ind, channel) << shift ind >>= 3 return colormap def create_cityscapes_colormap(): colors = [(128, 64, 128), (244, 35, 232), (250, 170, 160), (230, 150, 140), (70, 70, 70), (102, 102, 156), (190, 153, 153), (180, 165, 180), (150, 100, 100), (150, 120, 90), (153, 153, 153), (153, 153, 153), (250, 170, 30), (220, 220, 0), (107, 142, 35), (152, 251, 152), (70, 130, 180), (220, 20, 60), (255, 0, 0), (0, 0, 142), (0, 0, 70), (0, 60, 100), (0, 0, 90), (0, 0, 110), (0, 80, 100), (0, 0, 230), (119, 11, 32), (0, 0, 0)] return np.array(colors) class DirectoryDataset(Dataset): def __init__(self, root, path, image_set, transform, target_transform): super(DirectoryDataset, self).__init__() self.split = image_set self.dir = join(root, path) self.img_dir = join(self.dir, "imgs", self.split) self.label_dir = join(self.dir, "labels", self.split) self.transform = transform self.target_transform = target_transform self.img_files = np.array(sorted(os.listdir(self.img_dir))) assert len(self.img_files) > 0 if os.path.exists(join(self.dir, "labels")): self.label_files = np.array(sorted(os.listdir(self.label_dir))) assert len(self.img_files) == len(self.label_files) else: self.label_files = None self.fine_to_coarse = {0: 0, 1: 1, 2: 2, 3: 3, 4: 4, 5: 5, 6: 6, 7: -1, } def __getitem__(self, index): image_fn = self.img_files[index] img = Image.open(join(self.img_dir, image_fn)) if self.label_files is not None: label_fn = self.label_files[index] label = Image.open(join(self.label_dir, label_fn)) seed = np.random.randint(2147483647) random.seed(seed) torch.manual_seed(seed) img = self.transform(img) if self.label_files is not None: random.seed(seed) torch.manual_seed(seed) label = self.target_transform(label) new_label_map = torch.zeros_like(label) for fine, coarse in self.fine_to_coarse.items(): new_label_map[label == fine] = coarse label = new_label_map else: label = torch.zeros(img.shape[1], img.shape[2], dtype=torch.int64) - 1 mask = (label > 0).to(torch.float32) return img, label, mask def __len__(self): return len(self.img_files) class Potsdam(Dataset): def __init__(self, root, image_set, transform, target_transform, coarse_labels): super(Potsdam, self).__init__() self.split = image_set self.root = os.path.join(root, "potsdam") self.transform = transform self.target_transform = target_transform split_files = { "train": ["labelled_train.txt"], "unlabelled_train": ["unlabelled_train.txt"], # "train": ["unlabelled_train.txt"], "val": ["labelled_test.txt"], "train+val": ["labelled_train.txt", "labelled_test.txt"], "all": ["all.txt"] } assert self.split in split_files.keys() self.files = [] for split_file in split_files[self.split]: with open(join(self.root, split_file), "r") as f: self.files.extend(fn.rstrip() for fn in f.readlines()) self.coarse_labels = coarse_labels self.fine_to_coarse = {0: 0, 4: 0, # roads and cars 1: 1, 5: 1, # buildings and clutter 2: 2, 3: 2, # vegetation and trees 255: -1 } def __getitem__(self, index): image_id = self.files[index] img = loadmat(join(self.root, "imgs", image_id + ".mat"))["img"] img = to_pil_image(torch.from_numpy(img).permute(2, 0, 1)[:3]) # TODO add ir channel back try: label = loadmat(join(self.root, "gt", image_id + ".mat"))["gt"] label = to_pil_image(torch.from_numpy(label).unsqueeze(-1).permute(2, 0, 1)) except FileNotFoundError: label = to_pil_image(torch.ones(1, img.height, img.width)) seed = np.random.randint(2147483647) random.seed(seed) torch.manual_seed(seed) img = self.transform(img) random.seed(seed) torch.manual_seed(seed) label = self.target_transform(label).squeeze(0) if self.coarse_labels: new_label_map = torch.zeros_like(label) for fine, coarse in self.fine_to_coarse.items(): new_label_map[label == fine] = coarse label = new_label_map mask = (label > 0).to(torch.float32) return img, label, mask def __len__(self): return len(self.files) class PotsdamRaw(Dataset): def __init__(self, root, image_set, transform, target_transform, coarse_labels): super(PotsdamRaw, self).__init__() self.split = image_set self.root = os.path.join(root, "potsdamraw", "processed") self.transform = transform self.target_transform = target_transform self.files = [] for im_num in range(38): for i_h in range(15): for i_w in range(15): self.files.append("{}_{}_{}.mat".format(im_num, i_h, i_w)) self.coarse_labels = coarse_labels self.fine_to_coarse = {0: 0, 4: 0, # roads and cars 1: 1, 5: 1, # buildings and clutter 2: 2, 3: 2, # vegetation and trees 255: -1 } def __getitem__(self, index): image_id = self.files[index] img = loadmat(join(self.root, "imgs", image_id))["img"] img = to_pil_image(torch.from_numpy(img).permute(2, 0, 1)[:3]) # TODO add ir channel back try: label = loadmat(join(self.root, "gt", image_id))["gt"] label = to_pil_image(torch.from_numpy(label).unsqueeze(-1).permute(2, 0, 1)) except FileNotFoundError: label = to_pil_image(torch.ones(1, img.height, img.width)) seed = np.random.randint(2147483647) random.seed(seed) torch.manual_seed(seed) img = self.transform(img) random.seed(seed) torch.manual_seed(seed) label = self.target_transform(label).squeeze(0) if self.coarse_labels: new_label_map = torch.zeros_like(label) for fine, coarse in self.fine_to_coarse.items(): new_label_map[label == fine] = coarse label = new_label_map mask = (label > 0).to(torch.float32) return img, label, mask def __len__(self): return len(self.files) class Coco(Dataset): def __init__(self, root, image_set, transform, target_transform, coarse_labels, exclude_things, subset=None): super(Coco, self).__init__() self.split = image_set self.root = join(root, "cocostuff") self.coarse_labels = coarse_labels self.transform = transform self.label_transform = target_transform self.subset = subset self.exclude_things = exclude_things if self.subset is None: self.image_list = "Coco164kFull_Stuff_Coarse.txt" elif self.subset == 6: # IIC Coarse self.image_list = "Coco164kFew_Stuff_6.txt" elif self.subset == 7: # IIC Fine self.image_list = "Coco164kFull_Stuff_Coarse_7.txt" assert self.split in ["train", "val", "train+val"] split_dirs = { "train": ["train2017"], "val": ["val2017"], "train+val": ["train2017", "val2017"] } self.image_files = [] self.label_files = [] for split_dir in split_dirs[self.split]: with open(join(self.root, "curated", split_dir, self.image_list), "r") as f: img_ids = [fn.rstrip() for fn in f.readlines()] for img_id in img_ids: self.image_files.append(join(self.root, "images", split_dir, img_id + ".jpg")) self.label_files.append(join(self.root, "annotations", split_dir, img_id + ".png")) self.fine_to_coarse = {0: 9, 1: 11, 2: 11, 3: 11, 4: 11, 5: 11, 6: 11, 7: 11, 8: 11, 9: 8, 10: 8, 11: 8, 12: 8, 13: 8, 14: 8, 15: 7, 16: 7, 17: 7, 18: 7, 19: 7, 20: 7, 21: 7, 22: 7, 23: 7, 24: 7, 25: 6, 26: 6, 27: 6, 28: 6, 29: 6, 30: 6, 31: 6, 32: 6, 33: 10, 34: 10, 35: 10, 36: 10, 37: 10, 38: 10, 39: 10, 40: 10, 41: 10, 42: 10, 43: 5, 44: 5, 45: 5, 46: 5, 47: 5, 48: 5, 49: 5, 50: 5, 51: 2, 52: 2, 53: 2, 54: 2, 55: 2, 56: 2, 57: 2, 58: 2, 59: 2, 60: 2, 61: 3, 62: 3, 63: 3, 64: 3, 65: 3, 66: 3, 67: 3, 68: 3, 69: 3, 70: 3, 71: 0, 72: 0, 73: 0, 74: 0, 75: 0, 76: 0, 77: 1, 78: 1, 79: 1, 80: 1, 81: 1, 82: 1, 83: 4, 84: 4, 85: 4, 86: 4, 87: 4, 88: 4, 89: 4, 90: 4, 91: 17, 92: 17, 93: 22, 94: 20, 95: 20, 96: 22, 97: 15, 98: 25, 99: 16, 100: 13, 101: 12, 102: 12, 103: 17, 104: 17, 105: 23, 106: 15, 107: 15, 108: 17, 109: 15, 110: 21, 111: 15, 112: 25, 113: 13, 114: 13, 115: 13, 116: 13, 117: 13, 118: 22, 119: 26, 120: 14, 121: 14, 122: 15, 123: 22, 124: 21, 125: 21, 126: 24, 127: 20, 128: 22, 129: 15, 130: 17, 131: 16, 132: 15, 133: 22, 134: 24, 135: 21, 136: 17, 137: 25, 138: 16, 139: 21, 140: 17, 141: 22, 142: 16, 143: 21, 144: 21, 145: 25, 146: 21, 147: 26, 148: 21, 149: 24, 150: 20, 151: 17, 152: 14, 153: 21, 154: 26, 155: 15, 156: 23, 157: 20, 158: 21, 159: 24, 160: 15, 161: 24, 162: 22, 163: 25, 164: 15, 165: 20, 166: 17, 167: 17, 168: 22, 169: 14, 170: 18, 171: 18, 172: 18, 173: 18, 174: 18, 175: 18, 176: 18, 177: 26, 178: 26, 179: 19, 180: 19, 181: 24} self._label_names = [ "ground-stuff", "plant-stuff", "sky-stuff", ] self.cocostuff3_coarse_classes = [23, 22, 21] self.first_stuff_index = 12 def __getitem__(self, index): image_path = self.image_files[index] label_path = self.label_files[index] seed = np.random.randint(2147483647) random.seed(seed) torch.manual_seed(seed) img = self.transform(Image.open(image_path).convert("RGB")) random.seed(seed) torch.manual_seed(seed) label = self.label_transform(Image.open(label_path)).squeeze(0) label[label == 255] = -1 # to be consistent with 10k coarse_label = torch.zeros_like(label) for fine, coarse in self.fine_to_coarse.items(): coarse_label[label == fine] = coarse coarse_label[label == -1] = -1 if self.coarse_labels: coarser_labels = -torch.ones_like(label) for i, c in enumerate(self.cocostuff3_coarse_classes): coarser_labels[coarse_label == c] = i return img, coarser_labels, coarser_labels >= 0 else: if self.exclude_things: return img, coarse_label - self.first_stuff_index, (coarse_label >= self.first_stuff_index) else: return img, coarse_label, coarse_label >= 0 def __len__(self): return len(self.image_files) class CityscapesSeg(Dataset): def __init__(self, root, image_set, transform, target_transform): super(CityscapesSeg, self).__init__() self.split = image_set self.root = join(root, "cityscapes") if image_set == "train": # our_image_set = "train_extra" # mode = "coarse" our_image_set = "train" mode = "fine" else: our_image_set = image_set mode = "fine" self.inner_loader = Cityscapes(self.root, our_image_set, mode=mode, target_type="semantic", transform=None, target_transform=None) self.transform = transform self.target_transform = target_transform self.first_nonvoid = 7 def __getitem__(self, index): if self.transform is not None: image, target = self.inner_loader[index] seed = np.random.randint(2147483647) random.seed(seed) torch.manual_seed(seed) image = self.transform(image) random.seed(seed) torch.manual_seed(seed) target = self.target_transform(target) target = target - self.first_nonvoid target[target < 0] = -1 mask = target == -1 return image, target.squeeze(0), mask else: return self.inner_loader[index] def __len__(self): return len(self.inner_loader) class CroppedDataset(Dataset): def __init__(self, root, dataset_name, crop_type, crop_ratio, image_set, transform, target_transform): super(CroppedDataset, self).__init__() self.dataset_name = dataset_name self.split = image_set self.root = join(root, "cropped", "{}_{}_crop_{}".format(dataset_name, crop_type, crop_ratio)) self.transform = transform self.target_transform = target_transform self.img_dir = join(self.root, "img", self.split) self.label_dir = join(self.root, "label", self.split) self.num_images = len(os.listdir(self.img_dir)) assert self.num_images == len(os.listdir(self.label_dir)) def __getitem__(self, index): image = Image.open(join(self.img_dir, "{}.jpg".format(index))).convert('RGB') target = Image.open(join(self.label_dir, "{}.png".format(index))) seed = np.random.randint(2147483647) random.seed(seed) torch.manual_seed(seed) image = self.transform(image) random.seed(seed) torch.manual_seed(seed) target = self.target_transform(target) target = target - 1 mask = target == -1 return image, target.squeeze(0), mask def __len__(self): return self.num_images class MaterializedDataset(Dataset): def __init__(self, ds): self.ds = ds self.materialized = [] loader = DataLoader(ds, num_workers=12, collate_fn=lambda l: l[0]) for batch in tqdm(loader): self.materialized.append(batch) def __len__(self): return len(self.ds) def __getitem__(self, ind): return self.materialized[ind] class ContrastiveSegDataset(Dataset): def __init__(self, pytorch_data_dir, dataset_name, crop_type, image_set, transform, target_transform, cfg, aug_geometric_transform=None, aug_photometric_transform=None, num_neighbors=5, compute_knns=False, mask=False, pos_labels=False, pos_images=False, extra_transform=None, model_type_override=None ): super(ContrastiveSegDataset).__init__() self.num_neighbors = num_neighbors self.image_set = image_set self.dataset_name = dataset_name self.mask = mask self.pos_labels = pos_labels self.pos_images = pos_images self.extra_transform = extra_transform if dataset_name == "potsdam": self.n_classes = 3 dataset_class = Potsdam extra_args = dict(coarse_labels=True) elif dataset_name == "potsdamraw": self.n_classes = 3 dataset_class = PotsdamRaw extra_args = dict(coarse_labels=True) elif dataset_name == "directory": self.n_classes = cfg.dir_dataset_n_classes dataset_class = DirectoryDataset extra_args = dict(path=cfg.dir_dataset_name) elif dataset_name == "cityscapes" and crop_type is None: self.n_classes = 27 dataset_class = CityscapesSeg extra_args = dict() elif dataset_name == "cityscapes" and crop_type is not None: self.n_classes = 27 dataset_class = CroppedDataset extra_args = dict(dataset_name="cityscapes", crop_type=crop_type, crop_ratio=cfg.crop_ratio) elif dataset_name == "cocostuff3": self.n_classes = 3 dataset_class = Coco extra_args = dict(coarse_labels=True, subset=6, exclude_things=True) elif dataset_name == "cocostuff15": self.n_classes = 15 dataset_class = Coco extra_args = dict(coarse_labels=False, subset=7, exclude_things=True) elif dataset_name == "cocostuff27" and crop_type is not None: self.n_classes = 27 dataset_class = CroppedDataset extra_args = dict(dataset_name="cocostuff27", crop_type=cfg.crop_type, crop_ratio=cfg.crop_ratio) elif dataset_name == "cocostuff27" and crop_type is None: self.n_classes = 27 dataset_class = Coco extra_args = dict(coarse_labels=False, subset=None, exclude_things=False) if image_set == "val": extra_args["subset"] = 7 else: raise ValueError("Unknown dataset: {}".format(dataset_name)) self.aug_geometric_transform = aug_geometric_transform self.aug_photometric_transform = aug_photometric_transform self.dataset = dataset_class( root=pytorch_data_dir, image_set=self.image_set, transform=transform, target_transform=target_transform, **extra_args) if model_type_override is not None: model_type = model_type_override else: model_type = cfg.model_type nice_dataset_name = cfg.dir_dataset_name if dataset_name == "directory" else dataset_name feature_cache_file = join(pytorch_data_dir, "nns", "nns_{}_{}_{}_{}_{}.npz".format( model_type, nice_dataset_name, image_set, crop_type, cfg.res)) if pos_labels or pos_images: if not os.path.exists(feature_cache_file) or compute_knns: raise ValueError("could not find nn file {} please run precompute_knns".format(feature_cache_file)) else: loaded = np.load(feature_cache_file) self.nns = loaded["nns"] assert len(self.dataset) == self.nns.shape[0] def __len__(self): return len(self.dataset) def _set_seed(self, seed): random.seed(seed) # apply this seed to img tranfsorms torch.manual_seed(seed) # needed for torchvision 0.7 def __getitem__(self, ind): pack = self.dataset[ind] if self.pos_images or self.pos_labels: ind_pos = self.nns[ind][torch.randint(low=1, high=self.num_neighbors + 1, size=[]).item()] pack_pos = self.dataset[ind_pos] seed = np.random.randint(2147483647) # make a seed with numpy generator self._set_seed(seed) coord_entries = torch.meshgrid([torch.linspace(-1, 1, pack[0].shape[1]), torch.linspace(-1, 1, pack[0].shape[2])]) coord = torch.cat([t.unsqueeze(0) for t in coord_entries], 0) if self.extra_transform is not None: extra_trans = self.extra_transform else: extra_trans = lambda i, x: x def squeeze_tuple(label_raw): if type(label_raw) == tuple: return tuple(x.squeeze() for x in label_raw) else: return label_raw.squeeze() ret = { "ind": ind, "img": extra_trans(ind, pack[0]), "label": squeeze_tuple(extra_trans(ind, pack[1])) } if self.pos_images: ret["img_pos"] = extra_trans(ind, pack_pos[0]) ret["ind_pos"] = ind_pos if self.mask: ret["mask"] = pack[2] if self.pos_labels: ret["label_pos"] = squeeze_tuple(extra_trans(ind, pack_pos[1])) ret["mask_pos"] = pack_pos[2] if self.aug_photometric_transform is not None: img_aug = self.aug_photometric_transform(self.aug_geometric_transform(pack[0])) self._set_seed(seed) coord_aug = self.aug_geometric_transform(coord) ret["img_aug"] = img_aug ret["coord_aug"] = coord_aug.permute(1, 2, 0) return ret