|
import os |
|
from PIL import Image, ImageFile |
|
import torch |
|
from torch.utils.data import Dataset |
|
import torchvision.transforms as transforms |
|
import matplotlib.pyplot as plt |
|
from pathlib import Path |
|
from glob import glob |
|
|
|
def adaptive_instance_normalization(x, y, eps=1e-5):
    """
    Adaptive Instance Normalization (AdaIN).

    Normalizes the content features per (sample, channel) over the spatial
    dimensions, then rescales and shifts them with the corresponding
    statistics of the style features.

    Args:
        x (torch.FloatTensor): Content image tensor, shape (N, C, H, W)
        y (torch.FloatTensor): Style image tensor, shape (N, C, H, W)
        eps (float, default=1e-5): Small value to avoid zero division

    Return:
        output (torch.FloatTensor): AdaIN style transferred output
    """
    # keepdim=True keeps the (N, C, 1, 1) shape so the stats broadcast
    # over the spatial dimensions.
    mean_c = torch.mean(x, dim=[2, 3], keepdim=True)
    mean_s = torch.mean(y, dim=[2, 3], keepdim=True)

    # eps guards against division by zero for constant feature maps.
    std_c = torch.std(x, dim=[2, 3], keepdim=True) + eps
    std_s = torch.std(y, dim=[2, 3], keepdim=True) + eps

    normalized = (x - mean_c) / std_c
    return normalized * std_s + mean_s
|
|
|
def transform(size):
    """
    Build the image preprocessing pipeline: resize, then convert to tensor.

    Args:
        size (int): Resize image size

    Return:
        output (torchvision.transforms.Compose): Composition of
            torchvision.transforms steps (Resize -> ToTensor)
    """
    return transforms.Compose([
        transforms.Resize(size),
        transforms.ToTensor(),
    ])
|
|
|
def grid_image(row, col, images, height=6, width=6, save_pth='grid.png'):
    """
    Generate and save an image that contains row x col grids of images.

    Args:
        row (int): number of rows
        col (int): number of columns
        images (list of PIL image): list of images.
        height (int): height of each image (inch)
        width (int): width of each image (inch)
        save_pth (str): save file path
    """
    # Total figure size in inches: one (width x height) cell per grid slot.
    fig = plt.figure(figsize=(col * width, row * height))
    for i, image in enumerate(images):
        plt.subplot(row, col, i + 1)  # subplot indices are 1-based
        plt.imshow(image)
        plt.axis('off')
    plt.subplots_adjust(wspace=0.01, hspace=0.01)
    plt.savefig(save_pth)
    # Close the figure explicitly: pyplot keeps every figure alive until
    # closed, so repeated calls would otherwise leak memory (matplotlib
    # warns after 20 open figures).
    plt.close(fig)
|
|
|
|
|
def linear_histogram_matching(content_tensor, style_tensor):
    """
    Given content_tensor and style_tensor, transform style_tensor's
    per-channel mean and standard deviation to those of content_tensor
    (linear histogram matching).

    Note: style_tensor is modified in place (as in the original
    implementation) and is also returned.

    Args:
        content_tensor (torch.FloatTensor): Content image, shape (N, C, H, W)
        style_tensor (torch.FloatTensor): Style image, shape (N, C, H, W)

    Return:
        style_tensor (torch.FloatTensor): histogram matched Style Image
    """
    # Per-(sample, channel) statistics over the spatial dimensions,
    # kept as (N, C, 1, 1) so they broadcast.
    mu_c = torch.mean(content_tensor, dim=(2, 3), keepdim=True)
    mu_s = torch.mean(style_tensor, dim=(2, 3), keepdim=True)

    # Bug fix: the original scaled by the ratio of *variances* (torch.var),
    # but linear histogram matching requires the ratio of standard
    # deviations: s' = (s - mu_s) * (sigma_c / sigma_s) + mu_c.
    # unbiased=False matches the original's population-variance choice.
    sigma_c = torch.sqrt(torch.var(content_tensor, dim=(2, 3), unbiased=False, keepdim=True))
    sigma_s = torch.sqrt(torch.var(style_tensor, dim=(2, 3), unbiased=False, keepdim=True))

    # NOTE(review): a constant style channel (sigma_s == 0) still divides
    # by zero, exactly as in the original.
    style_tensor.copy_((style_tensor - mu_s) * sigma_c / sigma_s + mu_c)
    return style_tensor
|
|
|
|
|
class TrainSet(Dataset):
    """
    Paired content/style training dataset.

    Yields the i-th content image together with the i-th style image,
    each run through the same resize / random-crop / normalize pipeline.
    """

    def __init__(self, content_dir, style_dir, crop_size=256):
        super().__init__()

        self.content_files = [Path(p) for p in glob(content_dir + '/*')]
        self.style_files = [Path(p) for p in glob(style_dir + '/*')]

        steps = [
            transforms.Resize(512, interpolation=transforms.InterpolationMode.BICUBIC),
            transforms.RandomCrop(crop_size),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
        ]
        self.transform = transforms.Compose(steps)

        # Let PIL load arbitrarily large and truncated image files
        # without raising.
        Image.MAX_IMAGE_PIXELS = None
        ImageFile.LOAD_TRUNCATED_IMAGES = True

    def __len__(self):
        # The shorter of the two file lists bounds the number of pairs.
        return min(len(self.content_files), len(self.style_files))

    def __getitem__(self, index):
        content_img = Image.open(self.content_files[index]).convert('RGB')
        style_img = Image.open(self.style_files[index]).convert('RGB')
        return self.transform(content_img), self.transform(style_img)
|
|
|
class Range(object):
    """
    Helper class for input argument range restriction.

    Compares equal to any value lying in the closed interval
    [start, end] (e.g. for use with argparse ``choices``).
    """

    def __init__(self, start, end):
        self.start, self.end = start, end

    def __eq__(self, other):
        # De Morgan's form of: start <= other <= end
        return not (other < self.start or self.end < other)