# PITI-Synthesis / glide_text2im / image_datasets_sketch.py
# tfwang's picture
# Update glide_text2im/image_datasets_sketch.py
# b3ff61f
import math
import random
from PIL import Image
import blobfile as bf
#from mpi4py import MPI
import numpy as np
from torch.utils.data import DataLoader, Dataset
import os
import torchvision.transforms as transforms
import torch as th
from functools import partial
import cv2
def get_params(size, resize_size, crop_size):
    """Sample random augmentation parameters for a resize-then-crop pipeline.

    Args:
        size: (width, height) of the source image.
        resize_size: target length of the image's short side after resizing.
        crop_size: side length of the square crop taken from the resized image.

    Returns:
        dict with:
            'crop_pos': (x, y) top-left corner of a random crop that fits
                inside the resized image ((0, 0) when the resized image is
                not larger than the crop along an axis).
            'flip': bool, True with probability 0.5.
    """
    w, h = size
    short_side, long_side = min(w, h), max(w, h)
    width_is_shorter = w == short_side
    # Scale so the short side becomes resize_size, preserving aspect ratio.
    long_side = int(resize_size * long_side / short_side)
    short_side = resize_size
    new_w, new_h = (short_side, long_side) if width_is_shorter else (long_side, short_side)
    # Random top-left corner; max(0, ...) guards against a resized image
    # that is smaller than the crop along an axis.
    x = random.randint(0, max(0, new_w - crop_size))
    y = random.randint(0, max(0, new_h - crop_size))
    flip = random.random() > 0.5
    return {'crop_pos': (x, y), 'flip': flip}
def get_transform(params, resize_size, crop_size, method=Image.BICUBIC, flip=True, crop = True):
    """Build a PIL-image transform: square resize plus optional horizontal flip.

    NOTE(review): `resize_size`, `crop`, and `params['crop_pos']` are accepted
    but never used here — presumably cropping is handled elsewhere or was
    dropped; confirm before relying on them. Kept for interface compatibility.
    The flip is applied only when both `flip` is True and `params['flip']`
    was sampled True.
    """
    steps = [transforms.Lambda(lambda img: __scale(img, crop_size, method))]
    if flip:
        steps.append(transforms.Lambda(lambda img: __flip(img, params['flip'])))
    return transforms.Compose(steps)
def get_tensor(normalize=True, toTensor=True):
    """Return a transform turning a PIL image into an (optionally normalized) tensor.

    Normalization maps each channel from [0, 1] to [-1, 1] via mean=std=0.5.
    Either step can be switched off independently through the flags.
    """
    steps = []
    if toTensor:
        steps.append(transforms.ToTensor())
    if normalize:
        steps.append(transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)))
    return transforms.Compose(steps)
def normalize():
    """Channel-wise normalization from [0, 1] to [-1, 1] (mean = std = 0.5)."""
    mean = (0.5, 0.5, 0.5)
    std = (0.5, 0.5, 0.5)
    return transforms.Normalize(mean, std)
def __scale(img, target_width, method=Image.BICUBIC):
    # Resize to a target_width x target_width square.
    # NOTE(review): this does NOT preserve aspect ratio — confirm that the
    # square resize is intended, since get_params computes aspect-aware sizes.
    square = (target_width, target_width)
    return img.resize(square, method)
def __flip(img, flip):
    """Mirror `img` left-right when `flip` is truthy; otherwise return it unchanged."""
    return img.transpose(Image.FLIP_LEFT_RIGHT) if flip else img