import torch | |
from utils import * | |
class AdaIN(torch.nn.Module):
    """Adaptive instance normalization (AdaIN) layer.

    The content input is instance-normalized, following an equation similar
    to a z-score, ``(x - mu) / sigma``. The result is then scaled by the
    style input's standard deviation and shifted by the style input's mean,
    both obtained from ``mean_and_std_of_image``.
    """

    def __init__(self, num_features: int = 3) -> None:
        """Create the instance-normalization submodule.

        Args:
            num_features: Channel count passed to ``torch.nn.InstanceNorm2d``.
                Defaults to 3, the value that was previously hard-coded, so
                existing ``AdaIN()`` callers are unaffected.
        """
        super().__init__()
        # Instance normalization is the basis of the AdaIN layer; it computes
        # the content statistics itself, so they are not calculated by hand.
        self.instance_norm = torch.nn.InstanceNorm2d(num_features)

    def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        """Apply the style statistics of ``y`` to the normalized ``x``.

        Args:
            x: Content input tensor, shaped (N, num_channels, Height, Width).
            y: Style input tensor.

        Returns:
            ``y_std * instance_norm(x) + y_mean``, with the same size as the
            normalized content tensor.
        """
        # NOTE(review): mean_and_std_of_image comes from the ``utils``
        # wildcard import and is not visible here. Presumably it returns
        # per-channel statistics shaped so they can be expanded to
        # (N, C, H, W) -- TODO confirm against utils.
        y_mean, y_std = mean_and_std_of_image(y)
        x_norm = self.instance_norm(x)

        # Expand the style statistics to the normalized content size so that
        # there are no shape errors when performing the AdaIN operation.
        out_size = x_norm.size()
        return y_std.expand(out_size) * x_norm + y_mean.expand(out_size)