import json
import math
import os
import pickle
import random

import imageio
import numpy as np
import torch
import torch.distributed as dist
from PIL import Image
from torch.utils.data import BatchSampler
from torch.utils.data import Dataset
from torch.utils.data._utils.collate import default_collate
from torchvision import transforms

from datasets import register


class _ImageFolderBase(Dataset):
    """Shared implementation of the folder-backed image datasets.

    Lists the files under ``path`` (or the names stored under ``split_key``
    in the JSON ``split_file``) and serves one item per index.

    Args:
        path: Directory that contains the files.
        split_file: Optional JSON file; when given, file names are read from
            ``json.load(f)[split_key]`` instead of listing ``path``.
        split_key: Key looked up in ``split_file``.
        first_k: When not ``None``, keep only the first ``first_k`` names.
        size: Edge length passed to ``transforms.Resize`` in
            ``self.img_transform``.
        repeat: ``len(self)`` is the number of files multiplied by this value.
        cache: ``'none'`` loads an item on every access; ``'in_memory'``
            loads every item once in the constructor.
        mask: Selects the mask variant of ``img_transform`` and makes
            ``img_process`` return the file path instead of a decoded image.

    Raises:
        ValueError: If ``cache`` is not one of the supported modes.
    """

    _CACHE_MODES = ('none', 'in_memory')

    def __init__(self, path, split_file=None, split_key=None, first_k=None,
                 size=None, repeat=1, cache='none', mask=False):
        # An unknown mode used to produce a silently empty dataset (and a
        # modulo-by-zero in __getitem__), so reject it up front.
        if cache not in self._CACHE_MODES:
            raise ValueError(
                "cache must be one of {}, got {!r}".format(
                    self._CACHE_MODES, cache))

        self.repeat = repeat
        self.cache = cache
        self.path = path
        self.Train = False  # not read anywhere in this file
        self.split_key = split_key
        self.size = size
        self.mask = mask

        # NOTE(review): img_transform is built here but never applied in this
        # file; presumably a wrapper dataset uses it -- confirm.
        if self.mask:
            self.img_transform = transforms.Compose([
                transforms.Resize((self.size, self.size),
                                  interpolation=Image.NEAREST),
                transforms.ToTensor(),
            ])
        else:
            self.img_transform = transforms.Compose([
                transforms.Resize((self.size, self.size)),
                transforms.ToTensor(),
                transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225]),
            ])

        if split_file is None:
            filenames = sorted(os.listdir(path))
        else:
            with open(split_file, 'r') as f:
                filenames = json.load(f)[split_key]
        if first_k is not None:
            filenames = filenames[:first_k]

        self.files = []
        for filename in filenames:
            self.append_file(os.path.join(path, filename))

    def append_file(self, file):
        """Register ``file``: store the path, or the loaded item when caching."""
        if self.cache == 'none':
            self.files.append(file)
        elif self.cache == 'in_memory':
            self.files.append(self.img_process(file))

    def __len__(self):
        return len(self.files) * self.repeat

    def __getitem__(self, idx):
        # Indices wrap around so that ``repeat`` > 1 revisits the same files.
        entry = self.files[idx % len(self.files)]
        if self.cache == 'none':
            return self.img_process(entry)
        return entry  # 'in_memory': already processed in append_file

    def img_process(self, file):
        """Return the item for ``file``.

        For mask datasets the path itself is returned (decoding is left to
        the caller); otherwise the file is opened as an RGB ``PIL.Image``.
        """
        if self.mask:
            return file
        return Image.open(file).convert('RGB')


@register('image-folder')
class ImageFolder(_ImageFolderBase):
    """Folder-backed image dataset registered as ``'image-folder'``."""


@register('paired-image-folders')
class PairedImageFolders(Dataset):
    """Pairs an image folder with a mask folder, item by item.

    ``kwargs`` are forwarded to both underlying ``ImageFolder`` datasets; the
    second one is additionally created with ``mask=True``.
    """

    def __init__(self, root_path_1, root_path_2, **kwargs):
        self.dataset_1 = ImageFolder(root_path_1, **kwargs)
        self.dataset_2 = ImageFolder(root_path_2, **kwargs, mask=True)

    def __len__(self):
        return len(self.dataset_1)

    def __getitem__(self, idx):
        return self.dataset_1[idx], self.dataset_2[idx]


class ImageFolder_multi_task(_ImageFolderBase):
    """Unregistered twin of ``ImageFolder`` used by the multi-task pairing."""


@register('paired-image-folders-multi-task')
class PairedImageFolders_multi_task(Dataset):
    """Multi-task variant of ``PairedImageFolders``.

    ``model`` is accepted for interface compatibility but is not used by this
    class. Remaining ``kwargs`` are forwarded to both underlying datasets; the
    second one is additionally created with ``mask=True``.
    """

    def __init__(self, root_path_1, root_path_2, model=None, **kwargs):
        self.dataset_1 = ImageFolder_multi_task(root_path_1, **kwargs)
        self.dataset_2 = ImageFolder_multi_task(root_path_2, **kwargs,
                                                mask=True)

    def __len__(self):
        return len(self.dataset_1)

    def __getitem__(self, idx):
        return self.dataset_1[idx], self.dataset_2[idx]


class MultiTaskDataset(Dataset):
    """Concatenation of several datasets addressed by ``(task_id, sample_id)``.

    ``task_id`` is the position of a dataset in ``datasets``. Items are
    fetched with a two-element index, which is what
    ``DistrubutedMultiTaskBatchSampler`` yields.

    Usage example::

        train_datasets = [dataset_a, dataset_b]
        multi_task_dataset = MultiTaskDataset(train_datasets)
        batch_sampler = DistrubutedMultiTaskBatchSampler(
            train_datasets, batch_size=4, num_replicas=None, rank=None)
        loader = DataLoader(multi_task_dataset, batch_sampler=batch_sampler)
    """

    def __init__(self, datasets):
        self._datasets = datasets
        # Task ids are simply the enumeration order, so they cannot collide.
        self._task_id_2_data_set_dic = dict(enumerate(datasets))

    def __len__(self):
        return sum(len(dataset) for dataset in self._datasets)

    def __getitem__(self, idx):
        task_id, sample_id = idx
        return self._task_id_2_data_set_dic[task_id][sample_id]


def collate_fn(batch):
    """Collate ``(input, target)`` samples, tolerating ``None`` input fields.

    When the first sample's input is not a tuple the batch is handed to
    ``default_collate`` unchanged. Otherwise every position of the input
    tuple is collated on its own (a position whose value is ``None`` in the
    first sample stays ``None``) and the collated targets are appended last.

    Returns:
        The ``default_collate`` result, or a list
        ``[field_0, ..., field_k, targets]``.
    """
    first_input = batch[0][0]
    if not isinstance(first_input, tuple):
        return default_collate(batch)

    collated = [
        None if probe is None
        else default_collate([sample[0][pos] for sample in batch])
        for pos, probe in enumerate(first_input)
    ]
    collated.append(default_collate([sample[1] for sample in batch]))
    return collated


class DistrubutedMultiTaskBatchSampler(BatchSampler):
    """Distributed batch sampler whose batches each come from a single task.

    Every dataset is cut into index batches once, in the constructor. On each
    iteration a global schedule of ``(task_id, batch)`` pairs is built, made
    exactly ``total_size`` long (truncated or padded by repetition), and split
    across replicas by striding, so each rank receives ``num_samples``
    distinct schedule entries.

    Args:
        datasets: Sequence of sized datasets, one per task.
        batch_size (int): Number of samples per batch.
        num_replicas (int or None): World size; ``None`` queries
            ``torch.distributed``.
        rank (int or None): Rank of this process; ``None`` queries
            ``torch.distributed``.
        mix_opt (int): Must be 0 or 1. Stored, but it does not currently
            influence the schedule.
        extra_task_ratio (float): Must be >= 0. Stored, but it does not
            currently influence the schedule.
        drop_last (bool): Drop the last incomplete batch of every dataset.
            When the total batch count is not divisible by ``num_replicas``
            it also selects truncation instead of padding.
        shuffle (bool): Shuffle the task order of the schedule, seeded with
            ``seed + epoch``. Call ``set_epoch`` before every epoch to get a
            different order.

    Raises:
        RuntimeError: If ``num_replicas`` or ``rank`` is ``None`` and the
            distributed package is unavailable.
        ValueError: If ``rank`` is outside ``[0, num_replicas - 1]``.
    """

    def __init__(self, datasets, batch_size, num_replicas, rank, mix_opt=0,
                 extra_task_ratio=0, drop_last=True, shuffle=True):
        if num_replicas is None:
            if not dist.is_available():
                raise RuntimeError(
                    "Requires distributed package to be available")
            num_replicas = dist.get_world_size()
        if rank is None:
            if not dist.is_available():
                raise RuntimeError(
                    "Requires distributed package to be available")
            rank = dist.get_rank()
        if rank >= num_replicas or rank < 0:
            raise ValueError(
                "Invalid rank {}, rank should be in the interval"
                " [0, {}]".format(rank, num_replicas - 1))

        self.num_replicas = num_replicas
        self.rank = rank
        assert mix_opt in [0, 1], 'mix_opt must equal 0 or 1'
        assert extra_task_ratio >= 0, 'extra_task_ratio must be >= 0'
        self._batch_size = batch_size
        self._mix_opt = mix_opt
        self._extra_task_ratio = extra_task_ratio
        self._drop_last = drop_last
        self.shuffle = shuffle

        # One entry per dataset; each entry is that dataset's sample indices
        # already split into batches.
        self._train_data_list = [
            self._get_index_batches(len(dataset), batch_size, self._drop_last)
            for dataset in datasets
        ]
        self.total_len = sum(
            len(train_data) for train_data in self._train_data_list)

        # Same arithmetic as torch's DistributedSampler, counted in batches.
        if self._drop_last and self.total_len % self.num_replicas != 0:
            self.num_samples = math.ceil(
                (self.total_len - self.num_replicas) / self.num_replicas)
        else:
            self.num_samples = math.ceil(self.total_len / self.num_replicas)
        self.total_size = self.num_samples * self.num_replicas
        self.epoch = 0
        self.seed = 0

    def set_epoch(self, epoch):
        """Set the epoch that is mixed into the shuffle seed."""
        self.epoch = epoch

    @staticmethod
    def _get_index_batches(dataset_len, batch_size, drop_last):
        """Split ``range(dataset_len)`` into consecutive index batches."""
        index = list(range(dataset_len))
        if drop_last and dataset_len % batch_size:
            del index[-(dataset_len % batch_size):]
        return [index[i:i + batch_size]
                for i in range(0, len(index), batch_size)]

    def __len__(self):
        return self.num_samples

    def __iter__(self):
        task_order = self._gen_task_indices(
            self._train_data_list, self._mix_opt, self._extra_task_ratio)
        if self.shuffle:
            # A private, seeded generator keeps every rank on the same
            # permutation; the strided split below relies on that.
            random.Random(self.seed + self.epoch).shuffle(task_order)

        # Pair every slot of the task order with that task's next batch
        # BEFORE splitting by rank; otherwise all ranks would consume the
        # same leading batches of each task.
        cursors = [iter(batches) for batches in self._train_data_list]
        schedule = [(task_id, next(cursors[task_id]))
                    for task_id in task_order]

        shortfall = self.total_size - len(schedule)
        if shortfall > 0 and schedule:
            # Pad by repetition so every rank gets exactly num_samples.
            repeats = math.ceil(shortfall / len(schedule))
            schedule += (schedule * repeats)[:shortfall]
        schedule = schedule[:self.total_size]

        for task_id, batch in schedule[self.rank::self.num_replicas]:
            yield [(task_id, sample_id) for sample_id in batch]

    @staticmethod
    def _gen_task_indices(train_data_list, mix_opt, extra_task_ratio):
        """Return one task id per batch: ``[0, 0, ..., 1, 1, ...]``.

        ``mix_opt`` and ``extra_task_ratio`` are accepted but not used.
        """
        return [task_id
                for task_id, batches in enumerate(train_data_list)
                for _ in batches]


# Correctly spelled alias; the original (misspelled) name is kept for callers.
DistributedMultiTaskBatchSampler = DistrubutedMultiTaskBatchSampler