sleepytaco's picture
initial commit
a4d851a
raw
history blame
No virus
2.05 kB
from PIL import Image
from torchvision import transforms
from skimage import io, transform, util
import numpy as np
import os
"""
Contains utility functions to work with images in tensor and jpg/png forms
"""
def load_image_tensor(image, path="", img_size=(256, 256)):
    """
    Square-crop and resize an image array, returning it as a PyTorch tensor.

    The image is cropped from the top-left corner to a square whose side is
    its shorter spatial edge. This trims the right-hand columns of a landscape
    image or the bottom rows of a portrait image. The square is then resized
    to ``img_size`` with anti-aliasing and converted with
    ``torchvision.transforms.ToTensor``.

    Args:
        image: Array-like image with a ``.shape`` of (H, W) or (H, W, C).
        path: Unused; kept so existing callers that pass it keep working.
        img_size: (height, width) of the returned tensor. Defaults to
            (256, 256).

    Returns:
        A float32 tensor of shape (C, *img_size), or (1, *img_size) for a
        2-D input. For uint8 input the values are in [0, 1], because skimage's
        ``resize`` rescales integer images to float by default. Other input
        dtypes are not checked here.
    """
    height, width = image.shape[0], image.shape[1]
    side = min(height, width)
    # Trim only the two spatial axes and leave any channel axis intact.
    # Using ``width - height`` alone is negative for portrait images, which
    # makes crop() a no-op and lets resize() distort the aspect ratio.
    extra_axes = ((0, 0),) * (len(image.shape) - 2)
    crop_width = ((0, height - side), (0, width - side)) + extra_axes
    cropped_image = util.crop(image, crop_width)
    resized_image = transform.resize(
        image=cropped_image, output_shape=img_size, anti_aliasing=True
    )
    # ToTensor moves the channel axis first: (H, W, C) -> (C, H, W).
    tensor = transforms.ToTensor()(resized_image)
    return tensor.float()
def convert_tensor_to_PIL_image(image_tensor):
    """
    Convert an image tensor to an 8-bit PIL Image.

    Args:
        image_tensor: A (C, H, W) tensor or a 2-D (H, W) tensor. A
            single-channel (1, H, W) tensor becomes a single-band image.

    Returns:
        A PIL Image. Values are clipped to [0, 1] and scaled to 0-255 with
        truncation, not rounding.

    The tensor is detached and copied to the CPU before conversion, so tensors
    that require grad or live on a GPU are accepted.
    """
    array = image_tensor.detach().cpu().numpy()
    if array.ndim == 3:
        # PIL expects channels last: (C, H, W) -> (H, W, C).
        array = array.transpose(1, 2, 0)
        if array.shape[2] == 1:
            # Image.fromarray rejects an (H, W, 1) array; drop the channel.
            array = array[:, :, 0]
    array = (np.clip(array, 0, 1) * 255).astype(np.uint8)
    return Image.fromarray(array)
def save_image_tensor(tensor, output_dir="./", image_name="output.png"):
    """
    Save an image tensor to disk and return it as a PIL Image.

    Args:
        tensor: Image tensor accepted by ``convert_tensor_to_PIL_image``.
        output_dir: Directory to write into. It is created, along with any
            missing parents, if needed. An empty string means the current
            directory.
        image_name: File name. PIL infers the file format from its extension.

    Returns:
        The PIL Image that was written to ``output_dir/image_name``.
    """
    output_image = convert_tensor_to_PIL_image(tensor)
    if output_dir:
        # makedirs handles nested paths and an already-existing directory.
        os.makedirs(output_dir, exist_ok=True)
    # os.path.join adds the separator that plain string concatenation
    # dropped when output_dir had no trailing slash.
    output_image.save(os.path.join(output_dir, image_name))
    return output_image
def display_image_tensor(tensor):
    """
    Display an image tensor using PIL's ``Image.show``.

    Args:
        tensor: Image tensor accepted by ``convert_tensor_to_PIL_image``.
            The conversion is shared with that function instead of being
            duplicated here.
    """
    convert_tensor_to_PIL_image(tensor).show()
def get_grayscale(tensor):
    """
    Return a greyscale copy of an image tensor.

    Applies torchvision's ``Grayscale`` transform with its default single
    output channel. Per the torchvision docs, tensor input is expected to be
    shaped [..., 3, H, W].
    """
    to_grey = transforms.Grayscale(num_output_channels=1)
    return to_grey(tensor)