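# File-viewer header captured with this source (kept as comments so the module parses):
# AK391 · files · d380b77 · raw | history | blame · No virus · 7.05 kB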
import glob
import os
import cv2
import PIL.Image as Image
import numpy as np
from torch.utils.data import Dataset
import torch.nn.functional as F
def load_image(fname, mode='RGB', return_orig=False):
    """Load an image file as a float32 array scaled to [0, 1].

    Args:
        fname: Path of the image file to open.
        mode: PIL mode the image is converted to before reading pixels
            (e.g. 'RGB' for colour images, 'L' for single-channel masks).
        return_orig: If True, also return the unscaled pixel array.

    Returns:
        The pixel array divided by 255 as float32. Images with a channel
        axis are transposed to channel-first (C, H, W); single-channel modes
        such as 'L' stay 2-D (H, W). When ``return_orig`` is True, a tuple
        ``(float_img, raw_img)`` is returned, where ``raw_img`` has the same
        layout but keeps the dtype produced by ``np.array`` (no scaling).
    """
    # Context manager releases the file handle deterministically, even for
    # formats that Pillow keeps open after reading pixel data.
    with Image.open(fname) as pil_img:
        img = np.array(pil_img.convert(mode))
    if img.ndim == 3:
        # (H, W, C) -> (C, H, W)
        img = np.transpose(img, (2, 0, 1))
    out_img = img.astype('float32') / 255
    if return_orig:
        return out_img, img
    return out_img
def ceil_modulo(x, mod):
    """Round ``x`` up to the nearest multiple of ``mod``.

    A value that is already a multiple of ``mod`` is returned unchanged.
    """
    quotient, remainder = divmod(x, mod)
    if remainder == 0:
        return x
    return (quotient + 1) * mod
def pad_img_to_modulo(img, mod):
    """Pad a (C, H, W) array at the bottom/right so H and W become multiples of ``mod``.

    Uses numpy's 'symmetric' padding mode; the top and left edges and the
    channel axis are never padded.
    """
    _, height, width = img.shape
    pad_bottom = ceil_modulo(height, mod) - height
    pad_right = ceil_modulo(width, mod) - width
    no_pad = (0, 0)
    return np.pad(img, (no_pad, (0, pad_bottom), (0, pad_right)), mode='symmetric')
def pad_tensor_to_modulo(img, mod):
    """Pad a (N, C, H, W) tensor at the bottom/right so H and W become multiples of ``mod``.

    Uses ``torch.nn.functional.pad`` in 'reflect' mode.
    NOTE(review): reflect padding is expected to fail when the required pad is
    not smaller than the image size along that axis — confirm for tiny inputs.
    """
    _, _, height, width = img.shape
    extra_h = ceil_modulo(height, mod) - height
    extra_w = ceil_modulo(width, mod) - width
    # F.pad lists the last dimension first: (left, right, top, bottom).
    return F.pad(img, pad=(0, extra_w, 0, extra_h), mode='reflect')
def scale_image(img, factor, interpolation=cv2.INTER_AREA):
    """Resize a channel-first image by ``factor`` along both spatial axes.

    Args:
        img: Array of shape (C, H, W).
        factor: Scale applied to both height and width.
        interpolation: OpenCV interpolation flag forwarded to ``cv2.resize``.

    Returns:
        The resized array, again channel-first; a 2-D resize result is
        returned with a leading channel axis of size 1.
    """
    single_channel = img.shape[0] == 1
    # cv2.resize expects (H, W) or (H, W, C).
    hwc = img[0] if single_channel else np.transpose(img, (1, 2, 0))
    resized = cv2.resize(hwc, dsize=None, fx=factor, fy=factor, interpolation=interpolation)
    if resized.ndim == 2:
        return resized[None, ...]
    return np.transpose(resized, (2, 0, 1))
class InpaintingDataset(Dataset):
    """Evaluation dataset of (image, mask) pairs found under ``datadir``.

    Every ``*mask*.png`` file below ``datadir`` (searched recursively) is a
    mask. Its image path is the mask path cut at the last ``_mask`` with
    ``img_suffix`` appended (e.g. ``foo_mask000.png`` -> ``foo.jpg``).
    """

    def __init__(self, datadir, img_suffix='.jpg', pad_out_to_modulo=None, scale_factor=None):
        self.datadir = datadir
        mask_pattern = os.path.join(self.datadir, '**', '*mask*.png')
        self.mask_filenames = sorted(glob.glob(mask_pattern, recursive=True))
        self.img_filenames = [
            mask_path.rsplit('_mask', 1)[0] + img_suffix
            for mask_path in self.mask_filenames
        ]
        self.pad_out_to_modulo = pad_out_to_modulo
        self.scale_factor = scale_factor

    def __len__(self):
        return len(self.mask_filenames)

    def __getitem__(self, i):
        """Return sample ``i`` as ``{'image': (3, H, W), 'mask': (1, H, W)}`` float32 arrays.

        Both arrays are optionally rescaled by ``scale_factor`` and padded to a
        multiple of ``pad_out_to_modulo``.
        """
        sample = {
            'image': load_image(self.img_filenames[i], mode='RGB'),
            'mask': load_image(self.mask_filenames[i], mode='L')[None, ...],
        }
        if self.scale_factor is not None:
            sample['image'] = scale_image(sample['image'], self.scale_factor)
            # Nearest-neighbour copies mask values instead of averaging them.
            sample['mask'] = scale_image(sample['mask'], self.scale_factor, interpolation=cv2.INTER_NEAREST)
        if self.pad_out_to_modulo is not None and self.pad_out_to_modulo > 1:
            for key in ('image', 'mask'):
                sample[key] = pad_img_to_modulo(sample[key], self.pad_out_to_modulo)
        return sample
class OurInpaintingDataset(Dataset):
    """Evaluation dataset with masks under ``<datadir>/mask`` and images under ``<datadir>/img``.

    Every ``*mask*.png`` below ``<datadir>/mask`` (searched recursively) is a
    mask. The matching image is ``<datadir>/img/<stem>.png``. ``<stem>`` is the
    mask's file name cut at its last ``-`` and then at its last ``_``
    (e.g. ``mask/photo_mask-3.png`` -> ``img/photo.png``).
    """

    def __init__(self, datadir, img_suffix='.jpg', pad_out_to_modulo=None, scale_factor=None):
        # NOTE: ``img_suffix`` is accepted for signature compatibility with
        # InpaintingDataset but is unused; images are always looked up as '.png'.
        self.datadir = datadir
        mask_pattern = os.path.join(self.datadir, 'mask', '**', '*mask*.png')
        self.mask_filenames = sorted(glob.glob(mask_pattern, recursive=True))
        self.img_filenames = [
            os.path.join(self.datadir, 'img', self._image_stem(fname) + '.png')
            for fname in self.mask_filenames
        ]
        self.pad_out_to_modulo = pad_out_to_modulo
        self.scale_factor = scale_factor

    @staticmethod
    def _image_stem(mask_path):
        """Return the image base name encoded in a mask file name."""
        # Split the file name only: splitting the full path would cut into
        # directory names that happen to contain '-' or '_'.
        return os.path.basename(mask_path).rsplit('-', 1)[0].rsplit('_', 1)[0]

    def __len__(self):
        return len(self.mask_filenames)

    def __getitem__(self, i):
        """Return sample ``i`` as ``{'image': (3, H, W), 'mask': (1, H, W)}`` float32 arrays.

        Both arrays are optionally rescaled by ``scale_factor`` and padded to a
        multiple of ``pad_out_to_modulo``.
        """
        result = dict(image=load_image(self.img_filenames[i], mode='RGB'),
                      mask=load_image(self.mask_filenames[i], mode='L')[None, ...])
        if self.scale_factor is not None:
            result['image'] = scale_image(result['image'], self.scale_factor)
            # Nearest-neighbour copies mask values instead of averaging them,
            # matching the other datasets in this module.
            result['mask'] = scale_image(result['mask'], self.scale_factor, interpolation=cv2.INTER_NEAREST)
        if self.pad_out_to_modulo is not None and self.pad_out_to_modulo > 1:
            result['image'] = pad_img_to_modulo(result['image'], self.pad_out_to_modulo)
            result['mask'] = pad_img_to_modulo(result['mask'], self.pad_out_to_modulo)
        return result
class PrecomputedInpaintingResultsDataset(InpaintingDataset):
    """InpaintingDataset whose samples also carry a precomputed inpainting result.

    For a mask ``<datadir>/<rel>/<name>.png``, the prediction is read from
    ``<predictdir>/<rel>/<name><inpainted_suffix>``.
    """

    def __init__(self, datadir, predictdir, inpainted_suffix='_inpainted.jpg', **kwargs):
        super().__init__(datadir, **kwargs)
        # Slash-terminated prefix that is stripped from every mask path.
        prefix = datadir if datadir.endswith('/') else datadir + '/'
        self.predictdir = predictdir
        self.pred_filenames = []
        for mask_path in self.mask_filenames:
            relative_stem = os.path.splitext(mask_path[len(prefix):])[0]
            self.pred_filenames.append(os.path.join(predictdir, relative_stem + inpainted_suffix))

    def __getitem__(self, i):
        sample = super().__getitem__(i)
        sample['inpainted'] = load_image(self.pred_filenames[i])
        # NOTE(review): 'inpainted' is padded but never rescaled by scale_factor,
        # unlike image/mask — confirm predictions are stored at the scaled size.
        modulo = self.pad_out_to_modulo
        if modulo is not None and modulo > 1:
            sample['inpainted'] = pad_img_to_modulo(sample['inpainted'], modulo)
        return sample
class OurPrecomputedInpaintingResultsDataset(OurInpaintingDataset):
    """OurInpaintingDataset whose samples also carry a precomputed inpainting result.

    For a mask ``.../<name>.png``, the prediction is read from the flat
    directory ``<predictdir>/<name>_inpainted.<inpainted_suffix>``; any
    subdirectories of the mask path are dropped.
    """

    def __init__(self, datadir, predictdir, inpainted_suffix="png", **kwargs):
        super().__init__(datadir, **kwargs)
        self.predictdir = predictdir
        # ``inpainted_suffix`` is a bare file extension (no leading dot).
        self.pred_filenames = [
            os.path.join(predictdir,
                         os.path.basename(os.path.splitext(fname)[0]) + f'_inpainted.{inpainted_suffix}')
            for fname in self.mask_filenames
        ]

    def __getitem__(self, i):
        result = super().__getitem__(i)
        # Load the prediction with the module-level loader (RGB, float32 in [0, 1]).
        result['inpainted'] = load_image(self.pred_filenames[i])
        if self.pad_out_to_modulo is not None and self.pad_out_to_modulo > 1:
            result['inpainted'] = pad_img_to_modulo(result['inpainted'], self.pad_out_to_modulo)
        return result
class InpaintingEvalOnlineDataset(Dataset):
    """Evaluation dataset that generates a mask on the fly for each image under ``indir``.

    Args:
        indir: Root directory searched recursively for ``*<img_suffix>`` files.
        mask_generator: Callable invoked as ``mask_generator(img, raw_image=raw)``
            with the float image and the unscaled array returned by
            ``load_image(..., return_orig=True)``; its result is the mask.
        img_suffix: File suffix used to find images.
        pad_out_to_modulo: If greater than 1, pad image and mask to a multiple of it.
        scale_factor: If set, resize image and mask by this factor.
        **kwargs: Accepted and ignored.
    """

    def __init__(self, indir, mask_generator, img_suffix='.jpg', pad_out_to_modulo=None, scale_factor=None, **kwargs):
        self.indir = indir
        self.mask_generator = mask_generator
        image_pattern = os.path.join(self.indir, '**', f'*{img_suffix}')
        self.img_filenames = sorted(glob.glob(image_pattern, recursive=True))
        self.pad_out_to_modulo = pad_out_to_modulo
        self.scale_factor = scale_factor

    def __len__(self):
        return len(self.img_filenames)

    def __getitem__(self, i):
        img, raw_image = load_image(self.img_filenames[i], mode='RGB', return_orig=True)
        sample = {'image': img, 'mask': self.mask_generator(img, raw_image=raw_image)}
        if self.scale_factor is not None:
            sample['image'] = scale_image(sample['image'], self.scale_factor)
            # Nearest-neighbour copies mask values instead of averaging them.
            sample['mask'] = scale_image(sample['mask'], self.scale_factor, interpolation=cv2.INTER_NEAREST)
        if self.pad_out_to_modulo is not None and self.pad_out_to_modulo > 1:
            for key in ('image', 'mask'):
                sample[key] = pad_img_to_modulo(sample[key], self.pad_out_to_modulo)
        return sample