File size: 4,369 Bytes
			
			a9d81c5  | 
								1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140  | 
								"""
Utility functions for the DIE demo.
"""
import torch
from PIL import Image
from torch import Tensor
from torchvision import transforms
def resize_image(
    image: Image.Image,
    max_size: int = 1024
) -> Image.Image:
    """
    Resize an image so that its larger side equals ``max_size``, keeping the aspect ratio.
    The shorter side is scaled by the same factor and rounded to the nearest pixel,
    but never below 1 pixel, so that very elongated images do not end up with a
    zero-sized target dimension.
    :param image: PIL image
    :param max_size: size of the new image larger side
    :return: the resized PIL image
    """
    width, height = image.size
    if height >= width:
        # height is the larger side (ties go here as well): pin it to max_size
        scale = max_size / height
        new_size = (max(1, round(scale * width)), max_size)
    else:
        # width is the larger side: pin it to max_size
        scale = max_size / width
        new_size = (max_size, max(1, round(scale * height)))
    # PIL expects the target size as (width, height)
    return image.resize(new_size)
def make_image_square(
    image: Image.Image,
    image_size: int = 1024
) -> Image.Image:
    """
    Pad the input image to a square canvas.
    The image is pasted into the top-left corner of a white canvas; the rest of
    the canvas stays white.
    :param image: PIL image, mode 'L' or 'RGB'
    :param image_size: side of the square canvas; if the image's larger side
        exceeds it, the larger side is used instead
    :return: the square-sized PIL image
    :raises NotImplementedError: if the image mode is neither 'L' nor 'RGB'
    """
    # white fill value for each supported mode
    white_by_mode = {
        'L': (255,),
        'RGB': (255, 255, 255),
    }
    # the canvas must be able to hold the whole image
    side = max(image_size, max(image.size))
    fill = white_by_mode.get(image.mode)
    if fill is None:
        raise NotImplementedError("Not implemented image mode.")
    canvas = Image.new(image.mode, (side, side), fill)
    # the original content goes to the top-left corner
    canvas.paste(image, (0, 0))
    return canvas
def cast_pil_image_to_torch_tensor_with_4_channel_dim(
    image: Image.Image,
    device: str | None = None
) -> Tensor:
    """
    Convert a PIL image into a float32 torch tensor with four channels.
    The first three channels come from the RGB conversion of the image and the
    fourth one from its grayscale ('L') conversion. All values are divided by 255.
    :param image: input image
    :param device: target device for the result; the tensor is not moved when None
    :return: torch tensor (RGB and gray channels concatenated along dim 0)
    """
    to_tensor = transforms.PILToTensor()
    # both conversions are cast to float32 and scaled by 1/255
    rgb_channels = to_tensor(image.convert('RGB')).to(torch.float32) / 255.0
    gray_channel = to_tensor(image.convert('L')).to(torch.float32) / 255.0
    # gray goes after the RGB channels
    stacked = torch.cat((rgb_channels, gray_channel), dim=0)
    # only move the tensor when a device was requested
    return stacked if device is None else stacked.to(device)
def remove_square_padding(
    original_image: Image.Image | Tensor,
    square_image: Image.Image | Tensor,
    resize_back_to_original: bool = False
) -> Image.Image | Tensor:
    """
    Removing the square padding added to the original image to make square.
    The larger side of the original is taken to span the full side of the square
    image; the other side is scaled by the same factor and everything beyond it
    (to the right / at the bottom) is cut off.
    :param original_image: the image with the original size
    :param square_image: the image with the square size
    :param resize_back_to_original: defines if we want to resize the square image back to the original size
    :return: square image with the original size ratio
    """
    if isinstance(original_image, Image.Image):
        original_width, original_height = original_image.size
    else:
        # NOTE(review): assumes the first two dims are (height, width) - confirm with callers
        original_height, original_width = original_image.shape[:2]
    if isinstance(square_image, Image.Image):
        square_width, square_height = square_image.size
    else:
        # NOTE(review): same (height, width) layout assumption as above
        square_height, square_width = square_image.shape[:2]
    # The scaled side is rounded (not truncated) so it agrees with the size
    # resize_image computes for the same aspect ratio; it is kept at >= 1 pixel
    # so the crop is never empty.
    if original_width > original_height:
        ratio = square_width / original_width
        new_width = square_width
        new_height = max(1, round(ratio * original_height))
    else:
        ratio = square_height / original_height
        new_height = square_height
        new_width = max(1, round(ratio * original_width))
    # cutting size of the square image to the original ratio
    if isinstance(square_image, Image.Image):
        square_image_with_original_ratio = square_image.crop((0, 0, new_width, new_height))
    else:
        square_image_with_original_ratio = square_image[:new_height, :new_width]
    if resize_back_to_original:
        # NOTE(review): `.resize((width, height))` is the PIL resampling call; for
        # non-PIL inputs it is presumably not an image resize - confirm that this
        # flag is only used with PIL images.
        square_image_with_original_ratio = square_image_with_original_ratio.resize((original_width, original_height))
    return square_image_with_original_ratio
 |