# UnderWater / util/get_transform.py
# Last commit by Yarflam (a8eef7d): "Fix inference + scaling, update Gradio"
import torchvision.transforms as transforms
from PIL import Image
import numpy as np
def get_transform(opt, params=None, grayscale=False, method=Image.BICUBIC, convert=True):
    """Assemble the torchvision preprocessing pipeline described by ``opt``.

    Keywords are detected by substring in ``opt.preprocess``, so a mode such
    as ``'resize_and_crop'`` enables several steps. The one exception is
    ``'yarflam_auto'``, which must be the whole string. Steps are always
    added in this order:

    1. grayscale
    2. fixsize
    3. one of resize / scale_width / scale_shortside
    4. yarflam_auto
    5. zoom
    6. crop
    7. patch
    8. trim
    9. snap-to-multiple-of-4
    10. horizontal flip
    11. optional tensor conversion + normalization

    Args:
        opt: option namespace. Reads ``preprocess``, ``load_size``,
            ``crop_size`` and ``no_flip``. Also reads ``dataroot`` when
            resizing and ``yarflam_img_wh`` for ``'yarflam_auto'``.
        params: optional dict of pre-drawn augmentation parameters.
            Keys read:
            - ``'size'``: indexed unconditionally for ``'fixsize'``.
            - ``'scale_factor'``: used by ``'zoom'`` whenever params is given.
            - ``'crop_pos'``: when missing, crop falls back to
              ``RandomCrop``.
            - ``'patch_index'``: indexed unconditionally for ``'patch'``.
            - ``'flip'``: when missing, flip falls back to
              ``RandomHorizontalFlip``.
        grayscale: convert to one channel first and normalize with
            single-channel statistics.
        method: PIL resampling filter passed to every resize step.
        convert: when true, append ``ToTensor`` followed by a
            mean 0.5 / std 0.5 ``Normalize``.

    Returns:
        A ``transforms.Compose`` wrapping the selected steps.
    """
    pipeline = []
    mode = opt.preprocess

    if grayscale:
        pipeline.append(transforms.Grayscale(1))

    if 'fixsize' in mode:
        pipeline.append(transforms.Resize(params["size"], method))

    # At most one of the three rescaling strategies applies; 'resize' wins.
    if 'resize' in mode:
        out_size = [opt.load_size, opt.load_size]
        if "gta2cityscapes" in opt.dataroot:
            # Resize takes (h, w): this halves the output height only.
            out_size[0] = opt.load_size // 2
        pipeline.append(transforms.Resize(out_size, method))
    elif 'scale_width' in mode:
        pipeline.append(transforms.Lambda(
            lambda img: __scale_width(img, opt.load_size, opt.crop_size, method)))
    elif 'scale_shortside' in mode:
        pipeline.append(transforms.Lambda(
            lambda img: __scale_shortside(img, opt.load_size, opt.crop_size, method)))

    if mode == 'yarflam_auto':
        pipeline.append(transforms.Lambda(
            lambda img: __scale_yarflam(img, opt.yarflam_img_wh, method)))

    if 'zoom' in mode:
        if params is not None:
            # params["scale_factor"] is looked up when the transform runs.
            pipeline.append(transforms.Lambda(
                lambda img: __random_zoom(img, opt.load_size, opt.crop_size, method,
                                          factor=params["scale_factor"])))
        else:
            pipeline.append(transforms.Lambda(
                lambda img: __random_zoom(img, opt.load_size, opt.crop_size, method)))

    if 'crop' in mode:
        if params is not None and 'crop_pos' in params:
            pipeline.append(transforms.Lambda(
                lambda img: __crop(img, params['crop_pos'], opt.crop_size)))
        else:
            pipeline.append(transforms.RandomCrop(opt.crop_size))

    if 'patch' in mode:
        pipeline.append(transforms.Lambda(
            lambda img: __patch(img, params['patch_index'], opt.crop_size)))

    if 'trim' in mode:
        pipeline.append(transforms.Lambda(lambda img: __trim(img, opt.crop_size)))

    # Always snap both sides to a multiple of 4, whatever the preprocess mode.
    pipeline.append(transforms.Lambda(
        lambda img: __make_power_2(img, base=4, method=method)))

    if not opt.no_flip:
        if params is not None and 'flip' in params:
            pipeline.append(transforms.Lambda(lambda img: __flip(img, params['flip'])))
        else:
            pipeline.append(transforms.RandomHorizontalFlip())

    if convert:
        stats = (0.5,) if grayscale else (0.5, 0.5, 0.5)
        pipeline.append(transforms.ToTensor())
        pipeline.append(transforms.Normalize(stats, stats))

    return transforms.Compose(pipeline)
def __make_power_2(img, base, method=Image.BICUBIC):
    """Resize ``img`` so that both sides are multiples of ``base``.

    Despite the name, each side is snapped to the nearest multiple of
    ``base`` (Python ``round``, i.e. ties to even), not to a power of two.
    Each side is kept at least ``base`` pixels. Without that clamp, a side
    of at most ``base / 2`` pixels would round to 0, and PIL refuses to
    resize to a 0-px dimension.

    Args:
        img: PIL image.
        base: granularity, in pixels, of the output size.
        method: PIL resampling filter passed to ``img.resize``.

    Returns:
        ``img`` itself when it is already aligned, otherwise a resized copy.
    """
    ow, oh = img.size
    # round() yields 0 for tiny sides; clamp to one block of `base` pixels.
    h = int(max(1, round(oh / base)) * base)
    w = int(max(1, round(ow / base)) * base)
    if h == oh and w == ow:
        return img
    return img.resize((w, h), method)
def __random_zoom(img, target_width, crop_width, method=Image.BICUBIC, factor=None):
    """Rescale ``img`` by independent horizontal and vertical zoom factors.

    Args:
        img: PIL image.
        target_width: not read by this function (presumably kept so all
            scaling helpers share one call shape).
        crop_width: lower bound, in pixels, for both output sides.
        method: PIL resampling filter passed to ``img.resize``.
        factor: optional ``(x_zoom, y_zoom)`` pair; only its first two items
            are read. When ``None``, both zooms are drawn from
            ``np.random.uniform(0.8, 1.0)``.

    Returns:
        A resized copy of ``img``.
    """
    if factor is not None:
        zoom_x, zoom_y = factor[0], factor[1]
    else:
        zoom_x, zoom_y = np.random.uniform(0.8, 1.0, size=[2])
    width, height = img.size
    new_size = (
        int(round(max(crop_width, width * zoom_x))),
        int(round(max(crop_width, height * zoom_y))),
    )
    return img.resize(new_size, method)
def __scale_shortside(img, target_width, crop_width, method=Image.BICUBIC):
    """Upscale ``img`` so that its shorter side reaches ``target_width``.

    Images whose shorter side is already ``>= target_width`` are returned
    unchanged. The aspect ratio is preserved, with both sides rounded to the
    nearest integer. ``crop_width`` is accepted but not read.
    """
    width, height = img.size
    short = min(width, height)
    if short < target_width:
        factor = target_width / short
        return img.resize((round(width * factor), round(height * factor)), method)
    return img
def __trim(img, trim_width):
ow, oh = img.size
if ow > trim_width:
xstart = np.random.randint(ow - trim_width)
xend = xstart + trim_width
else:
xstart = 0
xend = ow
if oh > trim_width:
ystart = np.random.randint(oh - trim_width)
yend = ystart + trim_width
else:
ystart = 0
yend = oh
return img.crop((xstart, ystart, xend, yend))
def __scale_width(img, target_width, crop_width, method=Image.BICUBIC):
    """Rescale ``img`` to exactly ``target_width`` pixels wide.

    The new height follows the aspect ratio, truncated to an integer, but is
    never below ``crop_width`` (presumably so a later ``crop_width`` crop
    fits). Images that already have the target width and a height of at
    least ``crop_width`` are returned untouched.
    """
    width, height = img.size
    if not (width == target_width and height >= crop_width):
        scaled_height = target_width * height / width
        return img.resize((target_width, int(max(scaled_height, crop_width))), method)
    return img
def __scale_yarflam(img, target_wh, method=Image.BICUBIC):
    """Downscale ``img`` so that its longer side equals ``target_wh``.

    Images whose longer side is already ``<= target_wh`` are returned
    unchanged, so this never upscales. The aspect ratio is preserved: the
    shorter side is scaled by the same factor, floored to an integer and
    kept at least 1 px.

    Args:
        img: PIL image.
        target_wh: maximum allowed length, in pixels, of the longer side.
        method: PIL resampling filter passed to ``img.resize``.

    Returns:
        ``img`` itself, or a resized copy.
    """
    ow, oh = img.size
    long_side = max(ow, oh)
    if long_side <= target_wh:
        return img
    # Use integer arithmetic instead of a float ratio (target / long_side):
    # the float ratio can truncate the long side to target - 1. For very
    # thin images it can also truncate the short side to 0 px, which PIL
    # cannot resize to.
    target = int(target_wh)
    short_side = max(1, min(ow, oh) * target // long_side)
    if ow >= oh:
        w, h = target, short_side
    else:
        w, h = short_side, target
    return img.resize((w, h), method)
def __crop(img, pos, size):
ow, oh = img.size
x1, y1 = pos
tw = th = size
if (ow > tw or oh > th):
return img.crop((x1, y1, x1 + tw, y1 + th))
return img
def __patch(img, index, size):
ow, oh = img.size
nw, nh = ow // size, oh // size
roomx = ow - nw * size
roomy = oh - nh * size
startx = np.random.randint(int(roomx) + 1)
starty = np.random.randint(int(roomy) + 1)
index = index % (nw * nh)
ix = index // nh
iy = index % nh
gridx = startx + ix * size
gridy = starty + iy * size
return img.crop((gridx, gridy, gridx + size, gridy + size))
def __flip(img, flip):
if flip:
return img.transpose(Image.FLIP_LEFT_RIGHT)
return img