| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
| """ |
| Created in September 2022 |
| @author: fabrizio.guillaro |
| """ |
|
|
| from torch.utils.data import Dataset |
| import random |
|
|
| from dataset.dataset_FantasticReality import FantasticReality |
| from dataset.dataset_IMD2020 import IMD2020 |
| from dataset.dataset_CASIA import CASIA |
| from dataset.dataset_TampCOCO import tampCOCO |
| from dataset.dataset_CompRAISE import compRAISE |
|
|
|
|
class myDataset(Dataset):
    """Concatenation of the forensic sub-datasets selected in ``config``.

    Sub-datasets are chosen by membership tests of the keys ``'FR'``, ``'IMD'``,
    ``'CA'``, ``'COCO'`` and ``'RAISE'`` against ``config.DATASET.TRAIN``
    (``mode="train"``) or ``config.DATASET.VALID`` (``mode="valid"``). Each key
    adds one or more sub-datasets built from list files under ``dataset/data/``.

    Index layout depends on ``mode``:

    * ``"train"``: every sub-dataset contributes exactly ``self.smallest``
      samples, laid out one sub-dataset after another. ``self.smallest`` is the
      length of the shortest sub-dataset, optionally capped by
      ``config.TRAIN.NUM_SAMPLES``. This makes each sub-dataset equally
      represented.
    * ``"valid"``: the sub-datasets are concatenated in full.
    """

    def __init__(self, config, crop_size, grid_crop, mode="train", max_dim=None, aug=None):
        """Build the sub-datasets for ``mode``.

        Args:
            config: configuration object. Reads ``config.DATASET.TRAIN`` or
                ``config.DATASET.VALID`` (tested with ``in``; presumably a list of
                dataset keys — TODO confirm) and ``config.TRAIN.NUM_SAMPLES``.
            crop_size: forwarded positionally to every sub-dataset.
            grid_crop: forwarded positionally to every sub-dataset.
            mode: ``"train"`` or ``"valid"``.
            max_dim: forwarded to the sub-datasets in ``"valid"`` mode only.
                Training sub-datasets keep their own default.
            aug: forwarded to every sub-dataset.

        Raises:
            KeyError: if ``mode`` is neither ``"train"`` nor ``"valid"``.
            ValueError: if the configuration selects no sub-dataset.
        """
        if mode == "train":
            selected = config.DATASET.TRAIN
            extra = {}  # training sub-datasets are built without max_dim
        elif mode == "valid":
            selected = config.DATASET.VALID
            extra = {"max_dim": max_dim}
        else:
            raise KeyError("Invalid mode: " + mode)

        self.dataset_list = []

        def add(dataset_cls, list_name, **kwargs):
            # Every list file lives in dataset/data/ and shares the leading args.
            self.dataset_list.append(
                dataset_cls(crop_size, grid_crop, f"dataset/data/{list_name}",
                            aug=aug, **extra, **kwargs)
            )

        # The order below fixes the index layout of the concatenated dataset.
        # List names embed the split: "FR_train_list.txt" / "FR_valid_list.txt".
        if 'FR' in selected:
            add(FantasticReality, f"FR_{mode}_list.txt")
            add(FantasticReality, f"FR_auth_{mode}_list.txt", is_auth_list=True)

        if 'IMD' in selected:
            add(IMD2020, f"IMD_{mode}_list.txt")

        if 'CA' in selected:
            add(CASIA, f"CASIA_v2_{mode}_list.txt")
            add(CASIA, f"CASIA_v2_auth_{mode}_list.txt")

        if 'COCO' in selected:
            for prefix in ("cm", "sp", "bcm", "bcmc"):
                add(tampCOCO, f"{prefix}_COCO_{mode}_list.txt")

        if 'RAISE' in selected:
            add(compRAISE, f"compRAISE_{mode}.txt")

        if not self.dataset_list:
            # Fail with a clear message instead of min() on an empty sequence.
            raise ValueError(f"No dataset selected for mode '{mode}': got {selected!r}")

        self.crop_size = crop_size
        self.grid_crop = grid_crop
        self.mode = mode
        self.smallest = min(len(ds) for ds in self.dataset_list)
        num_samples = config.TRAIN.NUM_SAMPLES
        # A positive NUM_SAMPLES smaller than the shortest sub-dataset caps
        # the per-sub-dataset sample count.
        if 0 < num_samples < self.smallest:
            self.smallest = num_samples

    def shuffle(self):
        """Shuffle, in place, the ``img_list`` of every sub-dataset."""
        for dataset in self.dataset_list:
            random.shuffle(dataset.img_list)

    def _locate(self, index):
        """Map a global ``index`` to ``(sub_dataset, local_index)`` for the current mode.

        Raises:
            ValueError: in ``"train"`` mode, if ``index`` is past the end.
            IndexError: in other modes, if ``index`` is past the end.
        """
        if self.mode == 'train':
            if index < self.smallest * len(self.dataset_list):
                return self.dataset_list[index // self.smallest], index % self.smallest
            raise ValueError(
                f"Index {index} out of range for training set of length "
                f"{self.smallest * len(self.dataset_list)}"
            )
        # Walk the concatenation, consuming each sub-dataset's length in turn.
        for ds in self.dataset_list:
            if index < len(ds):
                return ds, index
            index -= len(ds)
        raise IndexError("myDataset index out of range")

    def get_filename(self, index):
        """Return the image name of the sample that ``self[index]`` would load."""
        ds, local_index = self._locate(index)
        return ds.get_img_name(local_index)

    def __len__(self):
        """``smallest * n_subdatasets`` in train mode, else the total number of images."""
        if self.mode == 'train':
            return self.smallest * len(self.dataset_list)
        return sum(len(ds) for ds in self.dataset_list)

    def __getitem__(self, index):
        """Load sample ``index`` via the owning sub-dataset's ``get_img``."""
        ds, local_index = self._locate(index)
        return ds.get_img(local_index)

    def get_info(self):
        """Return a human-readable summary of sub-dataset sizes and ``smallest``."""
        lines = ''.join(f'{ds.__class__.__name__}: \t{len(ds)} images \n'
                        for ds in self.dataset_list)
        return lines + f'Smallest: {self.smallest}'
|
|
|
|
|
|
|
|
|
|