Spaces:
Running
Running
from PIL import Image | |
import numpy as np | |
from skimage import color | |
import torch | |
import torch.nn.functional as F | |
from IPython import embed | |
def load_img(img_path):
    """Open the image at ``img_path`` and return its pixels as a NumPy array.

    A 2-D (single-channel) result is replicated three times along a new
    trailing axis so the returned array is H x W x 3. Arrays of any other
    dimensionality are returned exactly as PIL decoded them.
    """
    pixels = np.asarray(Image.open(img_path))
    if pixels.ndim != 2:
        return pixels
    # Grayscale: append a channel axis, then repeat it three times.
    return np.tile(pixels[:, :, None], 3)
def resize_img(img, HW=(256,256), resample=3):
    """Resize an image array with PIL and return the result as a NumPy array.

    ``HW`` is indexed as (height, width); PIL's ``resize`` expects
    (width, height), so the two entries are swapped before the call.
    ``resample`` is handed to PIL's ``resize`` unchanged (3 is the value of
    Pillow's bicubic filter constant).
    """
    target_h, target_w = HW[0], HW[1]
    pil_img = Image.fromarray(img)
    resized = pil_img.resize((target_w, target_h), resample=resample)
    return np.asarray(resized)
def preprocess_img(img_rgb_orig, HW=(256,256), resample=3):
    """Return the L channel of an RGB image at original and resized scale.

    The input is resized to ``HW`` via ``resize_img``; both the original and
    the resized image are converted to Lab with ``color.rgb2lab`` and channel
    0 (L) is kept.

    Returns:
        A tuple ``(tens_orig_l, tens_rs_l)`` of torch Tensors, each with two
        leading singleton axes added in front of the 2-D L channel.
    """
    def _l_channel_tensor(rgb):
        # Channel 0 of the Lab conversion is L; prepend two singleton axes.
        lab = color.rgb2lab(rgb)
        return torch.Tensor(lab[:, :, 0])[None, None, :, :]

    resized_rgb = resize_img(img_rgb_orig, HW=HW, resample=resample)
    return (_l_channel_tensor(img_rgb_orig), _l_channel_tensor(resized_rgb))
def postprocess_tens(tens_orig_l, out_ab, mode='bilinear'):
    """Merge an L tensor with predicted ab channels and convert to RGB.

    Args:
        tens_orig_l: L channel tensor, 1 x 1 x H_orig x W_orig.
        out_ab: ab channel tensor, 1 x 2 x H x W.
        mode: interpolation mode passed to ``F.interpolate`` when ``out_ab``
            has to be resized to H_orig x W_orig. Defaults to ``'bilinear'``.

    Returns:
        The output of ``color.lab2rgb`` for the first batch element, computed
        from an H_orig x W_orig x 3 Lab array.
    """
    HW_orig = tens_orig_l.shape[2:]
    HW = out_ab.shape[2:]

    # Resize the ab channels only when their spatial size differs from L's.
    if HW_orig[0] != HW[0] or HW_orig[1] != HW[1]:
        # Honor the caller's ``mode``; it was previously ignored in favor of
        # a hard-coded 'bilinear'.
        out_ab_orig = F.interpolate(out_ab, size=HW_orig, mode=mode)
    else:
        out_ab_orig = out_ab

    # Stack along the channel axis: 1 x 3 x H_orig x W_orig (L, a, b).
    out_lab_orig = torch.cat((tens_orig_l, out_ab_orig), dim=1)

    # Drop the batch axis and move channels last (C x H x W -> H x W x C).
    lab_hwc = out_lab_orig.detach().cpu().numpy()[0, ...].transpose((1, 2, 0))
    return color.lab2rgb(lab_hwc)