# Image_Restoration_Colorization/Global/data/online_dataset_for_old_photos.py
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import os.path
import io
import zipfile
from data.base_dataset import BaseDataset, get_params, get_transform, normalize
from data.image_folder import make_dataset
from PIL import Image
import torchvision.transforms as transforms
import numpy as np
from data.Load_Bigfile import BigFileMemoryLoader
import random
import cv2
from io import BytesIO
def pil_to_np(img_PIL):
    '''Converts an image in PIL format to a np.array.
    From H x W x C [0...255] to C x H x W [0..1]
    '''
ar = np.array(img_PIL)
if len(ar.shape) == 3:
ar = ar.transpose(2, 0, 1)
else:
ar = ar[None, ...]
return ar.astype(np.float32) / 255.
def np_to_pil(img_np):
    '''Converts an image in np.array format to a PIL image.
    From C x H x W [0..1] to H x W x C [0...255]
    '''
ar = np.clip(img_np * 255, 0, 255).astype(np.uint8)
if img_np.shape[0] == 1:
ar = ar[0]
else:
ar = ar.transpose(1, 2, 0)
return Image.fromarray(ar)
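# A minimal, hedged round-trip sketch for the two converters above; "example.jpg"
# is a hypothetical local file, not part of this repository.
def _demo_pil_np_roundtrip(path="example.jpg"):
    img = Image.open(path).convert("RGB")   # H x W x C uint8 PIL image
    arr = pil_to_np(img)                    # C x H x W float32 in [0, 1]
    return np_to_pil(arr)                   # back to an H x W x C PIL image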
def synthesize_salt_pepper(image,amount,salt_vs_pepper):
    ## Given a PIL image, return a PIL image corrupted with salt-and-pepper noise:
    ## `amount` is the fraction of flipped pixels, `salt_vs_pepper` the salt ratio
img_pil=pil_to_np(image)
out = img_pil.copy()
p = amount
q = salt_vs_pepper
flipped = np.random.choice([True, False], size=img_pil.shape,
p=[p, 1 - p])
salted = np.random.choice([True, False], size=img_pil.shape,
p=[q, 1 - q])
peppered = ~salted
out[flipped & salted] = 1
out[flipped & peppered] = 0.
noisy = np.clip(out, 0, 1).astype(np.float32)
return np_to_pil(noisy)
def synthesize_gaussian(image,std_l,std_r):
    ## Given a PIL image, return a PIL image with additive Gaussian noise (std drawn uniformly from [std_l, std_r] / 255)
img_pil=pil_to_np(image)
mean=0
std=random.uniform(std_l/255.,std_r/255.)
gauss=np.random.normal(loc=mean,scale=std,size=img_pil.shape)
noisy=img_pil+gauss
noisy=np.clip(noisy,0,1).astype(np.float32)
return np_to_pil(noisy)
def synthesize_speckle(image,std_l,std_r):
    ## Given a PIL image, return a PIL image with multiplicative speckle noise (std drawn uniformly from [std_l, std_r] / 255)
img_pil=pil_to_np(image)
mean=0
std=random.uniform(std_l/255.,std_r/255.)
gauss=np.random.normal(loc=mean,scale=std,size=img_pil.shape)
noisy=img_pil+gauss*img_pil
noisy=np.clip(noisy,0,1).astype(np.float32)
return np_to_pil(noisy)
def synthesize_low_resolution(img):
w,h=img.size
new_w=random.randint(int(w/2),w)
new_h=random.randint(int(h/2),h)
img=img.resize((new_w,new_h),Image.BICUBIC)
if random.uniform(0,1)<0.5:
img=img.resize((w,h),Image.NEAREST)
else:
img = img.resize((w, h), Image.BILINEAR)
return img
def convertToJpeg(im,quality):
with BytesIO() as f:
im.save(f, format='JPEG',quality=quality)
f.seek(0)
return Image.open(f).convert('RGB')
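# Note on convertToJpeg(): Image.open() is lazy, so the .convert('RGB') call above forces
# the JPEG bytes to be fully decoded before the BytesIO buffer is closed; the in-memory
# re-encode/decode cycle is what injects realistic JPEG compression artifacts.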
def blur_image_v2(img):
x=np.array(img)
kernel_size_candidate=[(3,3),(5,5),(7,7)]
kernel_size=random.sample(kernel_size_candidate,1)[0]
std=random.uniform(1.,5.)
#print("The gaussian kernel size: (%d,%d) std: %.2f"%(kernel_size[0],kernel_size[1],std))
blur=cv2.GaussianBlur(x,kernel_size,std)
return Image.fromarray(blur.astype(np.uint8))
def online_add_degradation_v2(img):
task_id=np.random.permutation(4)
for x in task_id:
if x==0 and random.uniform(0,1)<0.7:
img = blur_image_v2(img)
if x==1 and random.uniform(0,1)<0.7:
flag = random.choice([1, 2, 3])
if flag == 1:
img = synthesize_gaussian(img, 5, 50)
if flag == 2:
img = synthesize_speckle(img, 5, 50)
if flag == 3:
img = synthesize_salt_pepper(img, random.uniform(0, 0.01), random.uniform(0.3, 0.8))
if x==2 and random.uniform(0,1)<0.7:
img=synthesize_low_resolution(img)
if x==3 and random.uniform(0,1)<0.7:
img=convertToJpeg(img,random.randint(40,100))
return img
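# A hedged sketch of the degradation pipeline above: blur, noise, resolution loss and
# JPEG re-encoding are each applied with probability 0.7, in a random order. The
# filename "clean.png" is illustrative only, not part of this repository.
def _demo_online_degradation(path="clean.png"):
    clean = Image.open(path).convert("RGB")
    return online_add_degradation_v2(clean)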
def irregular_hole_synthesize(img,mask):
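    # Fill masked pixels (mask == 255) with white and return (holed RGB image, L-mode mask).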
img_np=np.array(img).astype('uint8')
mask_np=np.array(mask).astype('uint8')
mask_np=mask_np/255
img_new=img_np*(1-mask_np)+mask_np*255
hole_img=Image.fromarray(img_new.astype('uint8')).convert("RGB")
return hole_img,mask.convert("L")
def zero_mask(size):
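    # Return an all-black size x size RGB mask, i.e. a mask with no hole.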
x=np.zeros((size,size,3)).astype('uint8')
mask=Image.fromarray(x).convert("RGB")
return mask
class UnPairOldPhotos_SR(BaseDataset): ## Synthetic + Real Old
def initialize(self, opt):
self.opt = opt
self.isImage = 'domainA' in opt.name
self.task = 'old_photo_restoration_training_vae'
self.dir_AB = opt.dataroot
if self.isImage:
self.load_img_dir_L_old=os.path.join(self.dir_AB,"Real_L_old.bigfile")
self.load_img_dir_RGB_old=os.path.join(self.dir_AB,"Real_RGB_old.bigfile")
self.load_img_dir_clean=os.path.join(self.dir_AB,"VOC_RGB_JPEGImages.bigfile")
self.loaded_imgs_L_old=BigFileMemoryLoader(self.load_img_dir_L_old)
self.loaded_imgs_RGB_old=BigFileMemoryLoader(self.load_img_dir_RGB_old)
self.loaded_imgs_clean=BigFileMemoryLoader(self.load_img_dir_clean)
else:
# self.load_img_dir_clean=os.path.join(self.dir_AB,self.opt.test_dataset)
self.load_img_dir_clean=os.path.join(self.dir_AB,"VOC_RGB_JPEGImages.bigfile")
self.loaded_imgs_clean=BigFileMemoryLoader(self.load_img_dir_clean)
####
print("-------------Filter the imgs whose size <256 in VOC-------------")
self.filtered_imgs_clean=[]
for i in range(len(self.loaded_imgs_clean)):
img_name,img=self.loaded_imgs_clean[i]
            w,h=img.size   # PIL .size is (width, height)
if h<256 or w<256:
continue
self.filtered_imgs_clean.append((img_name,img))
print("--------Origin image num is [%d], filtered result is [%d]--------" % (
len(self.loaded_imgs_clean), len(self.filtered_imgs_clean)))
## Filter these images whose size is less than 256
# self.img_list=os.listdir(load_img_dir)
self.pid = os.getpid()
def __getitem__(self, index):
is_real_old=0
sampled_dataset=None
degradation=None
        if self.isImage: ## domain A holds two kinds of data: real old photos and clean VOC images that get synthetically degraded
P=random.uniform(0,2)
if P>=0 and P<1:
if random.uniform(0,1)<0.5:
sampled_dataset=self.loaded_imgs_L_old
self.load_img_dir=self.load_img_dir_L_old
else:
sampled_dataset=self.loaded_imgs_RGB_old
self.load_img_dir=self.load_img_dir_RGB_old
is_real_old=1
            if P>=1: ## also covers the (unlikely) boundary case P == 2.0
sampled_dataset=self.filtered_imgs_clean
self.load_img_dir=self.load_img_dir_clean
degradation=1
else:
sampled_dataset=self.filtered_imgs_clean
self.load_img_dir=self.load_img_dir_clean
sampled_dataset_len=len(sampled_dataset)
index=random.randint(0,sampled_dataset_len-1)
img_name,img = sampled_dataset[index]
if degradation is not None:
img=online_add_degradation_v2(img)
path=os.path.join(self.load_img_dir,img_name)
# AB = Image.open(path).convert('RGB')
# split AB image into A and B
# apply the same transform to both A and B
if random.uniform(0,1) <0.1:
img=img.convert("L")
img=img.convert("RGB")
            ## With probability 0.1, the RGB image is converted to grayscale (L) and back to RGB
A=img
w,h=A.size
if w<256 or h<256:
            A=transforms.Resize(256,Image.BICUBIC)(A)
        ## Since training only crops 256x256 patches, old photos smaller than 256 on either side are first resized.
transform_params = get_params(self.opt, A.size)
A_transform = get_transform(self.opt, transform_params)
B_tensor = inst_tensor = feat_tensor = 0
A_tensor = A_transform(A)
input_dict = {'label': A_tensor, 'inst': is_real_old, 'image': A_tensor,
'feat': feat_tensor, 'path': path}
return input_dict
def __len__(self):
        return len(self.loaded_imgs_clean) ## nominal length only: __getitem__ ignores the index and samples at random
def name(self):
return 'UnPairOldPhotos_SR'
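# A hedged usage sketch for the unpaired (VAE-stage) dataset above. `opt` is assumed to
# carry every option referenced in initialize()/__getitem__ (dataroot, name, and the
# transform options consumed by get_params/get_transform); building it is out of scope here.
def _demo_unpair_dataset(opt):
    dataset = UnPairOldPhotos_SR()
    dataset.initialize(opt)
    sample = dataset[0]                       # the index is ignored; an item is drawn at random
    return sample['label'], sample['inst']    # image tensor and the real-old flag (0 or 1)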
class PairOldPhotos(BaseDataset):
def initialize(self, opt):
self.opt = opt
self.isImage = 'imagegan' in opt.name
self.task = 'old_photo_restoration_training_mapping'
self.dir_AB = opt.dataroot
if opt.isTrain:
self.load_img_dir_clean= os.path.join(self.dir_AB, "VOC_RGB_JPEGImages.bigfile")
self.loaded_imgs_clean = BigFileMemoryLoader(self.load_img_dir_clean)
print("-------------Filter the imgs whose size <256 in VOC-------------")
self.filtered_imgs_clean = []
for i in range(len(self.loaded_imgs_clean)):
img_name, img = self.loaded_imgs_clean[i]
                w, h = img.size   # PIL .size is (width, height)
if h < 256 or w < 256:
continue
self.filtered_imgs_clean.append((img_name, img))
print("--------Origin image num is [%d], filtered result is [%d]--------" % (
len(self.loaded_imgs_clean), len(self.filtered_imgs_clean)))
else:
self.load_img_dir=os.path.join(self.dir_AB,opt.test_dataset)
self.loaded_imgs=BigFileMemoryLoader(self.load_img_dir)
self.pid = os.getpid()
def __getitem__(self, index):
if self.opt.isTrain:
img_name_clean,B = self.filtered_imgs_clean[index]
path = os.path.join(self.load_img_dir_clean, img_name_clean)
if self.opt.use_v2_degradation:
A=online_add_degradation_v2(B)
            ### Reminder: A is the degraded input and B is the corresponding ground truth
else:
if self.opt.test_on_synthetic:
img_name_B,B=self.loaded_imgs[index]
A=online_add_degradation_v2(B)
img_name_A=img_name_B
path = os.path.join(self.load_img_dir, img_name_A)
else:
img_name_A,A=self.loaded_imgs[index]
img_name_B,B=self.loaded_imgs[index]
path = os.path.join(self.load_img_dir, img_name_A)
if random.uniform(0,1)<0.1 and self.opt.isTrain:
A=A.convert("L")
B=B.convert("L")
A=A.convert("RGB")
B=B.convert("RGB")
        ## With probability 0.1 (training only), convert both A and B to grayscale (L) and back to RGB
        ## test on L
# split AB image into A and B
# w, h = img.size
# w2 = int(w / 2)
# A = img.crop((0, 0, w2, h))
# B = img.crop((w2, 0, w, h))
w,h=A.size
if w<256 or h<256:
            A=transforms.Resize(256,Image.BICUBIC)(A)
            B=transforms.Resize(256, Image.BICUBIC)(B)
# apply the same transform to both A and B
transform_params = get_params(self.opt, A.size)
A_transform = get_transform(self.opt, transform_params)
B_transform = get_transform(self.opt, transform_params)
B_tensor = inst_tensor = feat_tensor = 0
A_tensor = A_transform(A)
B_tensor = B_transform(B)
input_dict = {'label': A_tensor, 'inst': inst_tensor, 'image': B_tensor,
'feat': feat_tensor, 'path': path}
return input_dict
def __len__(self):
if self.opt.isTrain:
return len(self.filtered_imgs_clean)
else:
return len(self.loaded_imgs)
def name(self):
return 'PairOldPhotos'
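# A hedged sketch of iterating the paired dataset above with a standard PyTorch DataLoader;
# `opt` is assumed to be a fully populated options object for this repo.
def _demo_pair_dataloader(opt, batch_size=4):
    import torch.utils.data as data
    dataset = PairOldPhotos()
    dataset.initialize(opt)
    loader = data.DataLoader(dataset, batch_size=batch_size, shuffle=True)
    batch = next(iter(loader))                 # dict of batched tensors
    return batch['label'], batch['image']      # A (degraded input) and B (clean ground truth)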
class PairOldPhotos_with_hole(BaseDataset):
def initialize(self, opt):
self.opt = opt
self.isImage = 'imagegan' in opt.name
self.task = 'old_photo_restoration_training_mapping'
self.dir_AB = opt.dataroot
if opt.isTrain:
self.load_img_dir_clean= os.path.join(self.dir_AB, "VOC_RGB_JPEGImages.bigfile")
self.loaded_imgs_clean = BigFileMemoryLoader(self.load_img_dir_clean)
print("-------------Filter the imgs whose size <256 in VOC-------------")
self.filtered_imgs_clean = []
for i in range(len(self.loaded_imgs_clean)):
img_name, img = self.loaded_imgs_clean[i]
                w, h = img.size   # PIL .size is (width, height)
if h < 256 or w < 256:
continue
self.filtered_imgs_clean.append((img_name, img))
print("--------Origin image num is [%d], filtered result is [%d]--------" % (
len(self.loaded_imgs_clean), len(self.filtered_imgs_clean)))
else:
self.load_img_dir=os.path.join(self.dir_AB,opt.test_dataset)
self.loaded_imgs=BigFileMemoryLoader(self.load_img_dir)
self.loaded_masks = BigFileMemoryLoader(opt.irregular_mask)
self.pid = os.getpid()
def __getitem__(self, index):
if self.opt.isTrain:
img_name_clean,B = self.filtered_imgs_clean[index]
path = os.path.join(self.load_img_dir_clean, img_name_clean)
B=transforms.RandomCrop(256)(B)
A=online_add_degradation_v2(B)
            ### Reminder: A is the degraded input and B is the corresponding ground truth
else:
img_name_A,A=self.loaded_imgs[index]
img_name_B,B=self.loaded_imgs[index]
path = os.path.join(self.load_img_dir, img_name_A)
#A=A.resize((256,256))
A=transforms.CenterCrop(256)(A)
B=A
if random.uniform(0,1)<0.1 and self.opt.isTrain:
A=A.convert("L")
B=B.convert("L")
A=A.convert("RGB")
B=B.convert("RGB")
        ## With probability 0.1 (training only), convert both A and B to grayscale (L) and back to RGB
if self.opt.isTrain:
mask_name,mask=self.loaded_masks[random.randint(0,len(self.loaded_masks)-1)]
else:
mask_name, mask = self.loaded_masks[index%100]
mask = mask.resize((self.opt.loadSize, self.opt.loadSize), Image.NEAREST)
if self.opt.random_hole and random.uniform(0,1)>0.5 and self.opt.isTrain:
mask=zero_mask(256)
if self.opt.no_hole:
mask=zero_mask(256)
A,_=irregular_hole_synthesize(A,mask)
if not self.opt.isTrain and self.opt.hole_image_no_mask:
mask=zero_mask(256)
transform_params = get_params(self.opt, A.size)
A_transform = get_transform(self.opt, transform_params)
B_transform = get_transform(self.opt, transform_params)
if transform_params['flip'] and self.opt.isTrain:
mask=mask.transpose(Image.FLIP_LEFT_RIGHT)
mask_tensor = transforms.ToTensor()(mask)
B_tensor = inst_tensor = feat_tensor = 0
A_tensor = A_transform(A)
B_tensor = B_transform(B)
input_dict = {'label': A_tensor, 'inst': mask_tensor[:1], 'image': B_tensor,
'feat': feat_tensor, 'path': path}
return input_dict
def __len__(self):
if self.opt.isTrain:
return len(self.filtered_imgs_clean)
else:
return len(self.loaded_imgs)
def name(self):
return 'PairOldPhotos_with_hole'
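# A hedged smoke-test sketch for the hole-aware paired dataset above. As in the other
# demos, `opt` is assumed to provide every option the class references (including
# irregular_mask, loadSize, random_hole, no_hole and hole_image_no_mask).
def _demo_pair_with_hole(opt):
    dataset = PairOldPhotos_with_hole()
    dataset.initialize(opt)
    sample = dataset[0]
    # 'label' is the degraded, holed input, 'image' the clean target,
    # and 'inst' the single-channel hole mask aligned with both.
    return sample['label'], sample['image'], sample['inst']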