File size: 2,054 Bytes
a4d851a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
from PIL import Image
from torchvision import transforms
from skimage import io, transform, util
import numpy as np
import os

"""
Contains utility functions to work with images in tensor and jpg/png forms
"""


def load_image_tensor(image, path="", img_size=(256, 256)):
    """
    Crop an image to a square, resize it, and return it as a PyTorch tensor.

    The image is cropped to ``min(height, width)`` pixels along both spatial
    axes, keeping the top-left corner. Both portrait and landscape images are
    therefore square before resizing, and their aspect ratio is not distorted.

    Args:
        image: Image array of shape (H, W) or (H, W, C), or anything
            ``np.asarray`` turns into one.
        path: Unused; kept so existing callers that pass it keep working.
        img_size: Output (height, width) after resizing. Defaults to (256, 256).

    Returns:
        A float32 tensor of shape (C, *img_size), with C == 1 for 2-D input.
        For unsigned-integer input (e.g. uint8) the values lie in [0, 1],
        because ``transform.resize`` converts integer images to float.
    """
    image = np.asarray(image)
    # Slice BOTH spatial axes to the shorter side. Cropping only the width
    # leaves portrait images non-square, and resize would then stretch them.
    side = min(image.shape[0], image.shape[1])
    cropped_image = image[:side, :side]
    resized_image = transform.resize(image=cropped_image, output_shape=img_size, anti_aliasing=True)
    # ToTensor turns an (H, W[, C]) ndarray into a (C, H, W) tensor. It only
    # rescales uint8 data, and resize has already produced floats.
    tensor = transforms.ToTensor()(resized_image)
    return tensor.float()


def convert_tensor_to_PIL_image(image_tensor):
    """
    Convert a (C, H, W) image tensor with values in [0, 1] to a PIL image.

    Values are clipped to [0, 1], scaled to [0, 255] and rounded to the
    nearest integer before being cast to ``uint8``. A single-channel tensor
    becomes a 2-D (greyscale) image.

    Args:
        image_tensor: Tensor of shape (C, H, W). It may be on any device and
            may require grad.

    Returns:
        A ``PIL.Image.Image``.
    """
    # detach()/cpu() so tensors that require grad or live on a GPU convert too.
    array = image_tensor.detach().cpu().numpy().transpose(1, 2, 0)
    if array.shape[2] == 1:
        # PIL cannot build an image from an (H, W, 1) array; drop the channel axis.
        array = array[:, :, 0]
    # Round rather than truncate, so e.g. 0.99999 maps to 255 instead of 254.
    array = np.rint(np.clip(array, 0, 1) * 255).astype(np.uint8)
    return Image.fromarray(array)

def save_image_tensor(tensor, output_dir="./", image_name="output.png"):
    """
    Save a (C, H, W) image tensor to ``output_dir/image_name``.

    The tensor is converted with ``convert_tensor_to_PIL_image``. The output
    directory, including any missing parent directories, is created if needed.

    Args:
        tensor: Image tensor of shape (C, H, W) with values in [0, 1].
        output_dir: Directory to write into; a trailing separator is optional.
            An empty string means the current working directory.
        image_name: File name; PIL picks the output format from its extension.

    Returns:
        The saved ``PIL.Image.Image``.
    """
    output_image = convert_tensor_to_PIL_image(tensor)

    # makedirs(exist_ok=True) creates nested directories and avoids the
    # check-then-create race of exists() followed by mkdir().
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)
    # os.path.join inserts the separator that plain string concatenation
    # dropped when output_dir had no trailing slash.
    output_image.save(os.path.join(output_dir, image_name))

    return output_image


def display_image_tensor(tensor):
    """
    Display a (C, H, W) image tensor using PIL's ``Image.show``.

    The tensor-to-image conversion is delegated to
    ``convert_tensor_to_PIL_image`` instead of being duplicated here.

    Args:
        tensor: Image tensor of shape (C, H, W) with values in [0, 1].
    """
    convert_tensor_to_PIL_image(tensor).show()


def get_grayscale(tensor):
    """
    Return a greyscale version of an image tensor.

    Applies torchvision's ``transforms.Grayscale`` with its default settings
    and returns the transformed tensor unchanged.
    """
    to_greyscale = transforms.Grayscale()
    greyscale_tensor = to_greyscale(tensor)
    return greyscale_tensor