from PIL import Image
import numpy as np
from skimage import color
import torch
import torch.nn.functional as F
from IPython import embed


def load_img(img_path):
    """Load an image file as a NumPy array with a channel axis.

    Args:
        img_path: Path (or file object) accepted by ``PIL.Image.open``.

    Returns:
        The image as a NumPy array. A 2-D (single-channel) image is
        replicated along a new last axis to give an ``H x W x 3`` array;
        any other array is returned as decoded.
    """
    # The context manager releases the underlying file handle promptly
    # instead of leaving it to garbage collection.
    with Image.open(img_path) as img:
        out_np = np.asarray(img)
    if out_np.ndim == 2:
        # Grayscale: copy the single channel three times -> H x W x 3.
        out_np = np.tile(out_np[:, :, None], 3)
    return out_np


def resize_img(img, HW=(256, 256), resample=3):
    """Resize an image array to the given (height, width).

    Args:
        img: Image array accepted by ``PIL.Image.fromarray``.
        HW: Target size as ``(height, width)``.
        resample: PIL resampling filter code passed through to
            ``Image.resize`` (3 is presumably PIL's bicubic filter --
            confirm against the installed Pillow version).

    Returns:
        The resized image as a NumPy array.
    """
    # PIL's resize takes (width, height), hence the swapped order.
    target_wh = (HW[1], HW[0])
    return np.asarray(Image.fromarray(img).resize(target_wh, resample=resample))


def preprocess_img(img_rgb_orig, HW=(256, 256), resample=3):
    """Return the L channel at original and resized resolution as tensors.

    Args:
        img_rgb_orig: RGB image array (``H_orig x W_orig x 3``).
        HW: Target ``(height, width)`` for the resized copy.
        resample: PIL resampling filter code used for the resize.

    Returns:
        Tuple ``(tens_orig_l, tens_rs_l)`` of torch Tensors shaped
        ``1 x 1 x H_orig x W_orig`` and ``1 x 1 x HW[0] x HW[1]``.
    """
    img_rgb_rs = resize_img(img_rgb_orig, HW=HW, resample=resample)

    img_lab_orig = color.rgb2lab(img_rgb_orig)
    img_lab_rs = color.rgb2lab(img_rgb_rs)

    # Channel 0 of the Lab image is the lightness (L) channel.
    img_l_orig = img_lab_orig[:, :, 0]
    img_l_rs = img_lab_rs[:, :, 0]

    # Add leading batch and channel axes: H x W -> 1 x 1 x H x W.
    tens_orig_l = torch.Tensor(img_l_orig)[None, None, :, :]
    tens_rs_l = torch.Tensor(img_l_rs)[None, None, :, :]

    return (tens_orig_l, tens_rs_l)


def postprocess_tens(tens_orig_l, out_ab, mode='bilinear'):
    """Combine an L tensor with predicted ab channels into an RGB image.

    Args:
        tens_orig_l: L channel, ``1 x 1 x H_orig x W_orig``.
        out_ab: ab channels, ``1 x 2 x H x W``.
        mode: Interpolation mode forwarded to ``F.interpolate`` when
            ``out_ab`` must be resized to match ``tens_orig_l``.

    Returns:
        ``H_orig x W_orig x 3`` NumPy array produced by
        ``skimage.color.lab2rgb`` for the first batch element.
    """
    HW_orig = tens_orig_l.shape[2:]
    HW = out_ab.shape[2:]

    # Resize the ab channels only when the spatial sizes differ.
    if HW_orig[0] != HW[0] or HW_orig[1] != HW[1]:
        # Honour the caller's `mode`; it was previously ignored in favour
        # of a hard-coded 'bilinear'.
        out_ab_orig = F.interpolate(out_ab, size=HW_orig, mode=mode)
    else:
        out_ab_orig = out_ab

    out_lab_orig = torch.cat((tens_orig_l, out_ab_orig), dim=1)
    # Drop the batch axis and move channels last: 3 x H x W -> H x W x 3.
    lab_hwc = out_lab_orig.detach().cpu().numpy()[0, ...].transpose((1, 2, 0))
    return color.lab2rgb(lab_hwc)