import glob
import logging
import os
import random

import albumentations as A
import cv2
import numpy as np
import torch
import torch.nn.functional as F
import webdataset
from omegaconf import open_dict, OmegaConf
from skimage.feature import canny
from skimage.transform import rescale, resize
from torch.utils.data import Dataset, IterableDataset, DataLoader, DistributedSampler, ConcatDataset

from saicinpainting.evaluation.data import InpaintingDataset as InpaintingEvaluationDataset, \
    OurInpaintingDataset as OurInpaintingEvaluationDataset, ceil_modulo, InpaintingEvalOnlineDataset
from saicinpainting.training.data.aug import IAAAffine2, IAAPerspective2
from saicinpainting.training.data.masks import get_mask_generator

LOGGER = logging.getLogger(__name__)


class InpaintingTrainDataset(Dataset):
    def __init__(self, indir, mask_generator, transform):
        self.in_files = list(glob.glob(os.path.join(indir, '**', '*.jpg'), recursive=True))
        self.mask_generator = mask_generator
        self.transform = transform
        self.iter_i = 0

    def __len__(self):
        return len(self.in_files)

    def __getitem__(self, item):
        path = self.in_files[item]
        img = cv2.imread(path)
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        img = self.transform(image=img)['image']
        img = np.transpose(img, (2, 0, 1))
        # TODO: maybe generate mask before augmentations? slower, but better for segmentation-based masks
        mask = self.mask_generator(img, iter_i=self.iter_i)
        self.iter_i += 1
        return dict(image=img, mask=mask)


class InpaintingTrainWebDataset(IterableDataset):
    def __init__(self, indir, mask_generator, transform, shuffle_buffer=200):
        self.impl = webdataset.Dataset(indir).shuffle(shuffle_buffer).decode('rgb').to_tuple('jpg')
        self.mask_generator = mask_generator
        self.transform = transform

    def __iter__(self):
        for iter_i, (img,) in enumerate(self.impl):
            img = np.clip(img * 255, 0, 255).astype('uint8')
            img = self.transform(image=img)['image']
            img = np.transpose(img, (2, 0, 1))
            mask = self.mask_generator(img, iter_i=iter_i)
            yield dict(image=img, mask=mask)


class ImgSegmentationDataset(Dataset):
    def __init__(self, indir, mask_generator, transform, out_size, segm_indir, semantic_seg_n_classes):
        self.indir = indir
        self.segm_indir = segm_indir
        self.mask_generator = mask_generator
        self.transform = transform
        self.out_size = out_size
        self.semantic_seg_n_classes = semantic_seg_n_classes
        self.in_files = list(glob.glob(os.path.join(indir, '**', '*.jpg'), recursive=True))

    def __len__(self):
        return len(self.in_files)

    def __getitem__(self, item):
        path = self.in_files[item]
        img = cv2.imread(path)
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        img = cv2.resize(img, (self.out_size, self.out_size))
        img = self.transform(image=img)['image']
        img = np.transpose(img, (2, 0, 1))
        mask = self.mask_generator(img)
        segm, segm_classes = self.load_semantic_segm(path)
        result = dict(image=img, mask=mask, segm=segm, segm_classes=segm_classes)
        return result

    def load_semantic_segm(self, img_path):
        segm_path = img_path.replace(self.indir, self.segm_indir).replace(".jpg", ".png")
        mask = cv2.imread(segm_path, cv2.IMREAD_GRAYSCALE)
        mask = cv2.resize(mask, (self.out_size, self.out_size))
        tensor = torch.from_numpy(np.clip(mask.astype(int) - 1, 0, None))
        ohe = F.one_hot(tensor.long(), num_classes=self.semantic_seg_n_classes)  # w x h x n_classes
        return ohe.permute(2, 0, 1).float(), tensor.unsqueeze(0)
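# Example (sketch): manual wiring of a train dataset. The path and mask
# generator kind below are illustrative assumptions, not values from the
# training configs.
#
#   mask_generator = get_mask_generator(kind='mixed', kwargs=None)
#   transform = get_transforms('distortions', out_size=256)
#   dataset = InpaintingTrainDataset(indir='/data/train',  # hypothetical path
#                                    mask_generator=mask_generator,
#                                    transform=transform)
#   sample = dataset[0]  # dict with 'image' (C, H, W float32) and 'mask'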
def get_transforms(transform_variant, out_size):
    if transform_variant == 'default':
        transform = A.Compose([
            A.RandomScale(scale_limit=0.2),  # +/- 20%
            A.PadIfNeeded(min_height=out_size, min_width=out_size),
            A.RandomCrop(height=out_size, width=out_size),
            A.HorizontalFlip(),
            A.CLAHE(),
            A.RandomBrightnessContrast(brightness_limit=0.2, contrast_limit=0.2),
            A.HueSaturationValue(hue_shift_limit=5, sat_shift_limit=30, val_shift_limit=5),
            A.ToFloat()
        ])
    elif transform_variant == 'distortions':
        transform = A.Compose([
            IAAPerspective2(scale=(0.0, 0.06)),
            IAAAffine2(scale=(0.7, 1.3),
                       rotate=(-40, 40),
                       shear=(-0.1, 0.1)),
            A.PadIfNeeded(min_height=out_size, min_width=out_size),
            A.OpticalDistortion(),
            A.RandomCrop(height=out_size, width=out_size),
            A.HorizontalFlip(),
            A.CLAHE(),
            A.RandomBrightnessContrast(brightness_limit=0.2, contrast_limit=0.2),
            A.HueSaturationValue(hue_shift_limit=5, sat_shift_limit=30, val_shift_limit=5),
            A.ToFloat()
        ])
    elif transform_variant == 'distortions_scale05_1':
        transform = A.Compose([
            IAAPerspective2(scale=(0.0, 0.06)),
            IAAAffine2(scale=(0.5, 1.0),
                       rotate=(-40, 40),
                       shear=(-0.1, 0.1),
                       p=1),
            A.PadIfNeeded(min_height=out_size, min_width=out_size),
            A.OpticalDistortion(),
            A.RandomCrop(height=out_size, width=out_size),
            A.HorizontalFlip(),
            A.CLAHE(),
            A.RandomBrightnessContrast(brightness_limit=0.2, contrast_limit=0.2),
            A.HueSaturationValue(hue_shift_limit=5, sat_shift_limit=30, val_shift_limit=5),
            A.ToFloat()
        ])
    elif transform_variant == 'distortions_scale03_12':
        transform = A.Compose([
            IAAPerspective2(scale=(0.0, 0.06)),
            IAAAffine2(scale=(0.3, 1.2),
                       rotate=(-40, 40),
                       shear=(-0.1, 0.1),
                       p=1),
            A.PadIfNeeded(min_height=out_size, min_width=out_size),
            A.OpticalDistortion(),
            A.RandomCrop(height=out_size, width=out_size),
            A.HorizontalFlip(),
            A.CLAHE(),
            A.RandomBrightnessContrast(brightness_limit=0.2, contrast_limit=0.2),
            A.HueSaturationValue(hue_shift_limit=5, sat_shift_limit=30, val_shift_limit=5),
            A.ToFloat()
        ])
    elif transform_variant == 'distortions_scale03_07':
        transform = A.Compose([
            IAAPerspective2(scale=(0.0, 0.06)),
            IAAAffine2(scale=(0.3, 0.7),  # scales 512 down to ~256 on average
                       rotate=(-40, 40),
                       shear=(-0.1, 0.1),
                       p=1),
            A.PadIfNeeded(min_height=out_size, min_width=out_size),
            A.OpticalDistortion(),
            A.RandomCrop(height=out_size, width=out_size),
            A.HorizontalFlip(),
            A.CLAHE(),
            A.RandomBrightnessContrast(brightness_limit=0.2, contrast_limit=0.2),
            A.HueSaturationValue(hue_shift_limit=5, sat_shift_limit=30, val_shift_limit=5),
            A.ToFloat()
        ])
    elif transform_variant == 'distortions_light':
        transform = A.Compose([
            IAAPerspective2(scale=(0.0, 0.02)),
            IAAAffine2(scale=(0.8, 1.8),
                       rotate=(-20, 20),
                       shear=(-0.03, 0.03)),
            A.PadIfNeeded(min_height=out_size, min_width=out_size),
            A.RandomCrop(height=out_size, width=out_size),
            A.HorizontalFlip(),
            A.CLAHE(),
            A.RandomBrightnessContrast(brightness_limit=0.2, contrast_limit=0.2),
            A.HueSaturationValue(hue_shift_limit=5, sat_shift_limit=30, val_shift_limit=5),
            A.ToFloat()
        ])
    elif transform_variant == 'non_space_transform':
        transform = A.Compose([
            A.CLAHE(),
            A.RandomBrightnessContrast(brightness_limit=0.2, contrast_limit=0.2),
            A.HueSaturationValue(hue_shift_limit=5, sat_shift_limit=30, val_shift_limit=5),
            A.ToFloat()
        ])
    elif transform_variant == 'no_augs':
        transform = A.Compose([
            A.ToFloat()
        ])
    else:
        raise ValueError(f'Unexpected transform_variant {transform_variant}')
    return transform
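# Example (sketch): the pipelines above take HWC uint8 RGB input and end with
# A.ToFloat(), so the output image is float32 scaled to [0, 1]. The input
# shape below is an illustrative assumption:
#
#   transform = get_transforms('no_augs', out_size=256)
#   out = transform(image=np.zeros((300, 300, 3), dtype=np.uint8))['image']
#   # out.dtype == np.float32, values in [0, 1]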
def make_default_train_dataloader(indir, kind='default',
                                  out_size=512, mask_gen_kwargs=None, transform_variant='default',
                                  mask_generator_kind="mixed",
                                  dataloader_kwargs=None, ddp_kwargs=None, **kwargs):
    LOGGER.info(f'Make train dataloader {kind} from {indir}. Using mask generator={mask_generator_kind}')

    mask_generator = get_mask_generator(kind=mask_generator_kind, kwargs=mask_gen_kwargs)
    transform = get_transforms(transform_variant, out_size)

    if kind == 'default':
        dataset = InpaintingTrainDataset(indir=indir,
                                         mask_generator=mask_generator,
                                         transform=transform,
                                         **kwargs)
    elif kind == 'default_web':
        dataset = InpaintingTrainWebDataset(indir=indir,
                                            mask_generator=mask_generator,
                                            transform=transform,
                                            **kwargs)
    elif kind == 'img_with_segm':
        dataset = ImgSegmentationDataset(indir=indir,
                                         mask_generator=mask_generator,
                                         transform=transform,
                                         out_size=out_size,
                                         **kwargs)
    else:
        raise ValueError(f'Unknown train dataset kind {kind}')

    if dataloader_kwargs is None:
        dataloader_kwargs = {}

    is_dataset_only_iterable = kind in ('default_web',)

    if ddp_kwargs is not None and not is_dataset_only_iterable:
        dataloader_kwargs['shuffle'] = False
        dataloader_kwargs['sampler'] = DistributedSampler(dataset, **ddp_kwargs)

    if is_dataset_only_iterable and 'shuffle' in dataloader_kwargs:
        with open_dict(dataloader_kwargs):
            del dataloader_kwargs['shuffle']

    dataloader = DataLoader(dataset, **dataloader_kwargs)
    return dataloader


def make_default_val_dataset(indir, kind='default', out_size=512, transform_variant='default', **kwargs):
    if OmegaConf.is_list(indir) or isinstance(indir, (tuple, list)):
        return ConcatDataset([
            make_default_val_dataset(idir, kind=kind, out_size=out_size, transform_variant=transform_variant, **kwargs)
            for idir in indir
        ])

    LOGGER.info(f'Make val dataloader {kind} from {indir}')
    mask_generator = get_mask_generator(kind=kwargs.get("mask_generator_kind"),
                                        kwargs=kwargs.get("mask_gen_kwargs"))

    if transform_variant is not None:
        transform = get_transforms(transform_variant, out_size)

    if kind == 'default':
        dataset = InpaintingEvaluationDataset(indir, **kwargs)
    elif kind == 'our_eval':
        dataset = OurInpaintingEvaluationDataset(indir, **kwargs)
    elif kind == 'img_with_segm':
        dataset = ImgSegmentationDataset(indir=indir,
                                         mask_generator=mask_generator,
                                         transform=transform,
                                         out_size=out_size,
                                         **kwargs)
    elif kind == 'online':
        dataset = InpaintingEvalOnlineDataset(indir=indir,
                                              mask_generator=mask_generator,
                                              transform=transform,
                                              out_size=out_size,
                                              **kwargs)
    else:
        raise ValueError(f'Unknown val dataset kind {kind}')
    return dataset


def make_default_val_dataloader(*args, dataloader_kwargs=None, **kwargs):
    dataset = make_default_val_dataset(*args, **kwargs)

    if dataloader_kwargs is None:
        dataloader_kwargs = {}
    dataloader = DataLoader(dataset, **dataloader_kwargs)
    return dataloader


def make_constant_area_crop_params(img_height, img_width, min_size=128, max_size=512, area=256*256, round_to_mod=16):
    min_size = min(img_height, img_width, min_size)
    max_size = min(img_height, img_width, max_size)
    if random.random() < 0.5:
        out_height = min(max_size, ceil_modulo(random.randint(min_size, max_size), round_to_mod))
        out_width = min(max_size, ceil_modulo(area // out_height, round_to_mod))
    else:
        out_width = min(max_size, ceil_modulo(random.randint(min_size, max_size), round_to_mod))
        out_height = min(max_size, ceil_modulo(area // out_width, round_to_mod))

    start_y = random.randint(0, img_height - out_height)
    start_x = random.randint(0, img_width - out_width)
    return (start_y, start_x, out_height, out_width)
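# Example (sketch): applying a constant-area crop. One side is sampled
# uniformly, the other is derived as area // side; both are rounded up to a
# multiple of round_to_mod and clamped, so the crop area is only approximately
# `area`. `img` below is a hypothetical HWC numpy image:
#
#   y, x, h, w = make_constant_area_crop_params(img_height=img.shape[0],
#                                               img_width=img.shape[1])
#   crop = img[y:y + h, x:x + w]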