# Copyright (c) Microsoft Corporation. # Licensed under the MIT License. import os.path import io import zipfile from data.base_dataset import BaseDataset, get_params, get_transform, normalize from data.image_folder import make_dataset from PIL import Image import torchvision.transforms as transforms import numpy as np from data.Load_Bigfile import BigFileMemoryLoader import random import cv2 from io import BytesIO def pil_to_np(img_PIL): '''Converts image in PIL format to np.array. From W x H x C [0...255] to C x W x H [0..1] ''' ar = np.array(img_PIL) if len(ar.shape) == 3: ar = ar.transpose(2, 0, 1) else: ar = ar[None, ...] return ar.astype(np.float32) / 255. def np_to_pil(img_np): '''Converts image in np.array format to PIL image. From C x W x H [0..1] to W x H x C [0...255] ''' ar = np.clip(img_np * 255, 0, 255).astype(np.uint8) if img_np.shape[0] == 1: ar = ar[0] else: ar = ar.transpose(1, 2, 0) return Image.fromarray(ar) def synthesize_salt_pepper(image,amount,salt_vs_pepper): ## Give PIL, return the noisy PIL img_pil=pil_to_np(image) out = img_pil.copy() p = amount q = salt_vs_pepper flipped = np.random.choice([True, False], size=img_pil.shape, p=[p, 1 - p]) salted = np.random.choice([True, False], size=img_pil.shape, p=[q, 1 - q]) peppered = ~salted out[flipped & salted] = 1 out[flipped & peppered] = 0. noisy = np.clip(out, 0, 1).astype(np.float32) return np_to_pil(noisy) def synthesize_gaussian(image,std_l,std_r): ## Give PIL, return the noisy PIL img_pil=pil_to_np(image) mean=0 std=random.uniform(std_l/255.,std_r/255.) gauss=np.random.normal(loc=mean,scale=std,size=img_pil.shape) noisy=img_pil+gauss noisy=np.clip(noisy,0,1).astype(np.float32) return np_to_pil(noisy) def synthesize_speckle(image,std_l,std_r): ## Give PIL, return the noisy PIL img_pil=pil_to_np(image) mean=0 std=random.uniform(std_l/255.,std_r/255.) gauss=np.random.normal(loc=mean,scale=std,size=img_pil.shape) noisy=img_pil+gauss*img_pil noisy=np.clip(noisy,0,1).astype(np.float32) return np_to_pil(noisy) def synthesize_low_resolution(img): w,h=img.size new_w=random.randint(int(w/2),w) new_h=random.randint(int(h/2),h) img=img.resize((new_w,new_h),Image.BICUBIC) if random.uniform(0,1)<0.5: img=img.resize((w,h),Image.NEAREST) else: img = img.resize((w, h), Image.BILINEAR) return img def convertToJpeg(im,quality): with BytesIO() as f: im.save(f, format='JPEG',quality=quality) f.seek(0) return Image.open(f).convert('RGB') def blur_image_v2(img): x=np.array(img) kernel_size_candidate=[(3,3),(5,5),(7,7)] kernel_size=random.sample(kernel_size_candidate,1)[0] std=random.uniform(1.,5.) #print("The gaussian kernel size: (%d,%d) std: %.2f"%(kernel_size[0],kernel_size[1],std)) blur=cv2.GaussianBlur(x,kernel_size,std) return Image.fromarray(blur.astype(np.uint8)) def online_add_degradation_v2(img): task_id=np.random.permutation(4) for x in task_id: if x==0 and random.uniform(0,1)<0.7: img = blur_image_v2(img) if x==1 and random.uniform(0,1)<0.7: flag = random.choice([1, 2, 3]) if flag == 1: img = synthesize_gaussian(img, 5, 50) if flag == 2: img = synthesize_speckle(img, 5, 50) if flag == 3: img = synthesize_salt_pepper(img, random.uniform(0, 0.01), random.uniform(0.3, 0.8)) if x==2 and random.uniform(0,1)<0.7: img=synthesize_low_resolution(img) if x==3 and random.uniform(0,1)<0.7: img=convertToJpeg(img,random.randint(40,100)) return img def irregular_hole_synthesize(img,mask): img_np=np.array(img).astype('uint8') mask_np=np.array(mask).astype('uint8') mask_np=mask_np/255 img_new=img_np*(1-mask_np)+mask_np*255 hole_img=Image.fromarray(img_new.astype('uint8')).convert("RGB") return hole_img,mask.convert("L") def zero_mask(size): x=np.zeros((size,size,3)).astype('uint8') mask=Image.fromarray(x).convert("RGB") return mask class UnPairOldPhotos_SR(BaseDataset): ## Synthetic + Real Old def initialize(self, opt): self.opt = opt self.isImage = 'domainA' in opt.name self.task = 'old_photo_restoration_training_vae' self.dir_AB = opt.dataroot if self.isImage: self.load_img_dir_L_old=os.path.join(self.dir_AB,"Real_L_old.bigfile") self.load_img_dir_RGB_old=os.path.join(self.dir_AB,"Real_RGB_old.bigfile") self.load_img_dir_clean=os.path.join(self.dir_AB,"VOC_RGB_JPEGImages.bigfile") self.loaded_imgs_L_old=BigFileMemoryLoader(self.load_img_dir_L_old) self.loaded_imgs_RGB_old=BigFileMemoryLoader(self.load_img_dir_RGB_old) self.loaded_imgs_clean=BigFileMemoryLoader(self.load_img_dir_clean) else: # self.load_img_dir_clean=os.path.join(self.dir_AB,self.opt.test_dataset) self.load_img_dir_clean=os.path.join(self.dir_AB,"VOC_RGB_JPEGImages.bigfile") self.loaded_imgs_clean=BigFileMemoryLoader(self.load_img_dir_clean) #### print("-------------Filter the imgs whose size <256 in VOC-------------") self.filtered_imgs_clean=[] for i in range(len(self.loaded_imgs_clean)): img_name,img=self.loaded_imgs_clean[i] h,w=img.size if h<256 or w<256: continue self.filtered_imgs_clean.append((img_name,img)) print("--------Origin image num is [%d], filtered result is [%d]--------" % ( len(self.loaded_imgs_clean), len(self.filtered_imgs_clean))) ## Filter these images whose size is less than 256 # self.img_list=os.listdir(load_img_dir) self.pid = os.getpid() def __getitem__(self, index): is_real_old=0 sampled_dataset=None degradation=None if self.isImage: ## domain A , contains 2 kinds of data: synthetic + real_old P=random.uniform(0,2) if P>=0 and P<1: if random.uniform(0,1)<0.5: sampled_dataset=self.loaded_imgs_L_old self.load_img_dir=self.load_img_dir_L_old else: sampled_dataset=self.loaded_imgs_RGB_old self.load_img_dir=self.load_img_dir_RGB_old is_real_old=1 if P>=1 and P<2: sampled_dataset=self.filtered_imgs_clean self.load_img_dir=self.load_img_dir_clean degradation=1 else: sampled_dataset=self.filtered_imgs_clean self.load_img_dir=self.load_img_dir_clean sampled_dataset_len=len(sampled_dataset) index=random.randint(0,sampled_dataset_len-1) img_name,img = sampled_dataset[index] if degradation is not None: img=online_add_degradation_v2(img) path=os.path.join(self.load_img_dir,img_name) # AB = Image.open(path).convert('RGB') # split AB image into A and B # apply the same transform to both A and B if random.uniform(0,1) <0.1: img=img.convert("L") img=img.convert("RGB") ## Give a probability P, we convert the RGB image into L A=img w,h=A.size if w<256 or h<256: A=transforms.Scale(256,Image.BICUBIC)(A) ## Since we want to only crop the images (256*256), for those old photos whose size is smaller than 256, we first resize them. transform_params = get_params(self.opt, A.size) A_transform = get_transform(self.opt, transform_params) B_tensor = inst_tensor = feat_tensor = 0 A_tensor = A_transform(A) input_dict = {'label': A_tensor, 'inst': is_real_old, 'image': A_tensor, 'feat': feat_tensor, 'path': path} return input_dict def __len__(self): return len(self.loaded_imgs_clean) ## actually, this is useless, since the selected index is just a random number def name(self): return 'UnPairOldPhotos_SR' class PairOldPhotos(BaseDataset): def initialize(self, opt): self.opt = opt self.isImage = 'imagegan' in opt.name self.task = 'old_photo_restoration_training_mapping' self.dir_AB = opt.dataroot if opt.isTrain: self.load_img_dir_clean= os.path.join(self.dir_AB, "VOC_RGB_JPEGImages.bigfile") self.loaded_imgs_clean = BigFileMemoryLoader(self.load_img_dir_clean) print("-------------Filter the imgs whose size <256 in VOC-------------") self.filtered_imgs_clean = [] for i in range(len(self.loaded_imgs_clean)): img_name, img = self.loaded_imgs_clean[i] h, w = img.size if h < 256 or w < 256: continue self.filtered_imgs_clean.append((img_name, img)) print("--------Origin image num is [%d], filtered result is [%d]--------" % ( len(self.loaded_imgs_clean), len(self.filtered_imgs_clean))) else: self.load_img_dir=os.path.join(self.dir_AB,opt.test_dataset) self.loaded_imgs=BigFileMemoryLoader(self.load_img_dir) self.pid = os.getpid() def __getitem__(self, index): if self.opt.isTrain: img_name_clean,B = self.filtered_imgs_clean[index] path = os.path.join(self.load_img_dir_clean, img_name_clean) if self.opt.use_v2_degradation: A=online_add_degradation_v2(B) ### Remind: A is the input and B is corresponding GT else: if self.opt.test_on_synthetic: img_name_B,B=self.loaded_imgs[index] A=online_add_degradation_v2(B) img_name_A=img_name_B path = os.path.join(self.load_img_dir, img_name_A) else: img_name_A,A=self.loaded_imgs[index] img_name_B,B=self.loaded_imgs[index] path = os.path.join(self.load_img_dir, img_name_A) if random.uniform(0,1)<0.1 and self.opt.isTrain: A=A.convert("L") B=B.convert("L") A=A.convert("RGB") B=B.convert("RGB") ## In P, we convert the RGB into L ##test on L # split AB image into A and B # w, h = img.size # w2 = int(w / 2) # A = img.crop((0, 0, w2, h)) # B = img.crop((w2, 0, w, h)) w,h=A.size if w<256 or h<256: A=transforms.Scale(256,Image.BICUBIC)(A) B=transforms.Scale(256, Image.BICUBIC)(B) # apply the same transform to both A and B transform_params = get_params(self.opt, A.size) A_transform = get_transform(self.opt, transform_params) B_transform = get_transform(self.opt, transform_params) B_tensor = inst_tensor = feat_tensor = 0 A_tensor = A_transform(A) B_tensor = B_transform(B) input_dict = {'label': A_tensor, 'inst': inst_tensor, 'image': B_tensor, 'feat': feat_tensor, 'path': path} return input_dict def __len__(self): if self.opt.isTrain: return len(self.filtered_imgs_clean) else: return len(self.loaded_imgs) def name(self): return 'PairOldPhotos' class PairOldPhotos_with_hole(BaseDataset): def initialize(self, opt): self.opt = opt self.isImage = 'imagegan' in opt.name self.task = 'old_photo_restoration_training_mapping' self.dir_AB = opt.dataroot if opt.isTrain: self.load_img_dir_clean= os.path.join(self.dir_AB, "VOC_RGB_JPEGImages.bigfile") self.loaded_imgs_clean = BigFileMemoryLoader(self.load_img_dir_clean) print("-------------Filter the imgs whose size <256 in VOC-------------") self.filtered_imgs_clean = [] for i in range(len(self.loaded_imgs_clean)): img_name, img = self.loaded_imgs_clean[i] h, w = img.size if h < 256 or w < 256: continue self.filtered_imgs_clean.append((img_name, img)) print("--------Origin image num is [%d], filtered result is [%d]--------" % ( len(self.loaded_imgs_clean), len(self.filtered_imgs_clean))) else: self.load_img_dir=os.path.join(self.dir_AB,opt.test_dataset) self.loaded_imgs=BigFileMemoryLoader(self.load_img_dir) self.loaded_masks = BigFileMemoryLoader(opt.irregular_mask) self.pid = os.getpid() def __getitem__(self, index): if self.opt.isTrain: img_name_clean,B = self.filtered_imgs_clean[index] path = os.path.join(self.load_img_dir_clean, img_name_clean) B=transforms.RandomCrop(256)(B) A=online_add_degradation_v2(B) ### Remind: A is the input and B is corresponding GT else: img_name_A,A=self.loaded_imgs[index] img_name_B,B=self.loaded_imgs[index] path = os.path.join(self.load_img_dir, img_name_A) #A=A.resize((256,256)) A=transforms.CenterCrop(256)(A) B=A if random.uniform(0,1)<0.1 and self.opt.isTrain: A=A.convert("L") B=B.convert("L") A=A.convert("RGB") B=B.convert("RGB") ## In P, we convert the RGB into L if self.opt.isTrain: mask_name,mask=self.loaded_masks[random.randint(0,len(self.loaded_masks)-1)] else: mask_name, mask = self.loaded_masks[index%100] mask = mask.resize((self.opt.loadSize, self.opt.loadSize), Image.NEAREST) if self.opt.random_hole and random.uniform(0,1)>0.5 and self.opt.isTrain: mask=zero_mask(256) if self.opt.no_hole: mask=zero_mask(256) A,_=irregular_hole_synthesize(A,mask) if not self.opt.isTrain and self.opt.hole_image_no_mask: mask=zero_mask(256) transform_params = get_params(self.opt, A.size) A_transform = get_transform(self.opt, transform_params) B_transform = get_transform(self.opt, transform_params) if transform_params['flip'] and self.opt.isTrain: mask=mask.transpose(Image.FLIP_LEFT_RIGHT) mask_tensor = transforms.ToTensor()(mask) B_tensor = inst_tensor = feat_tensor = 0 A_tensor = A_transform(A) B_tensor = B_transform(B) input_dict = {'label': A_tensor, 'inst': mask_tensor[:1], 'image': B_tensor, 'feat': feat_tensor, 'path': path} return input_dict def __len__(self): if self.opt.isTrain: return len(self.filtered_imgs_clean) else: return len(self.loaded_imgs) def name(self): return 'PairOldPhotos_with_hole'