"""Helpers that turn PIL images into 360x360 grayscale tensors and back."""
import torch
import numpy as np
from PIL import Image
from torchvision.transforms import (
    Compose,
    Resize,
    Grayscale,
    ToTensor,
    ToPILImage,
)

# Module-level grayscale transform: resize to 360x360, reduce to a single
# channel, and convert to a float tensor (ToTensor scales pixels to [0, 1]).
transform_gs = Compose(
    [Resize((360, 360)), Grayscale(num_output_channels=1), ToTensor()]
)

# ToPILImage holds no per-call state, so one instance is reused.
_to_pil = ToPILImage()


def process_gs_image(image):
    """Convert a PIL image into a batched 360x360 grayscale tensor.

    Args:
        image: A PIL image.

    Returns:
        A tuple ``(tensor, original_size)``. ``tensor`` has a leading batch
        dimension of size 1 added to the output of ``transform_gs``.
        ``original_size`` is the input's ``(width, height)`` as reported
        by PIL, in the form ``inverse_transform_cs`` expects.
    """
    original_size = image.size  # PIL reports (width, height)
    tensor = transform_gs(image).unsqueeze(0)  # add the batch dimension
    return tensor, original_size


def inverse_transform_cs(tensor, original_size):
    """Convert an image tensor back to a PIL image at its original size.

    Args:
        tensor: Image tensor. A leading dimension of size 1 (the batch
            dimension added by ``process_gs_image``) is removed. Floating
            point values are interpreted on a [0, 1] scale; anything outside
            that range is clamped.
        original_size: Target ``(width, height)``, e.g. as returned by
            ``process_gs_image``.

    Returns:
        A PIL image resized to ``original_size``. Its mode is chosen by
        ``ToPILImage`` from the channel count ("L" for one channel).
    """
    # Detach from autograd and move to CPU before handing off to PIL.
    image = tensor.detach().cpu().squeeze(0)
    if image.is_floating_point():
        # ToPILImage computes (x * 255) cast to uint8, which wraps around for
        # out-of-range values (e.g. 1.01 -> 1, small negatives -> ~255).
        # Clamping first makes such pixels saturate instead.
        image = image.clamp(0.0, 1.0)
    pil_image = _to_pil(image)
    # PIL's resize takes (width, height), matching PIL's Image.size order.
    return pil_image.resize(original_size)