import torch
from PIL import Image
import torchvision.transforms as transforms

def load_image(image_path, device):
    """Load an image from disk, resize it, and return it as a batched tensor."""
    image_size = 356

    loader = transforms.Compose(
        [
            transforms.Resize((image_size, image_size)),
            transforms.ToTensor(),
        ]
    )

    # Force RGB so RGBA or grayscale inputs don't produce the wrong channel count.
    image = Image.open(image_path).convert("RGB")
    # Add a batch dimension: (C, H, W) -> (1, C, H, W).
    image = loader(image).unsqueeze(0)

    return image.to(device)

def tensor_to_image(tensor):
    """Convert a batched image tensor back into a PIL image."""
    # Detach from the autograd graph so the conversion doesn't track gradients.
    tensor = tensor.clone().detach()
    # Drop the batch dimension: (1, C, H, W) -> (C, H, W).
    tensor = tensor.squeeze(0)
    # Clamp to the valid [0, 1] range before conversion.
    tensor = torch.clamp(tensor, 0, 1)

    unloader = transforms.ToPILImage()
    image = unloader(tensor.cpu())
    return image
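
# --- Usage sketch (not part of the original file) ---
# A minimal round-trip through the two helpers above. The file names
# "content.jpg" and "roundtrip.jpg" are hypothetical placeholders.
if __name__ == "__main__":
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    img_tensor = load_image("content.jpg", device)  # shape: (1, 3, 356, 356)
    pil_img = tensor_to_image(img_tensor)
    pil_img.save("roundtrip.jpg")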