# File size: 1,200 Bytes
# 12b5a88
import torch
from utils import *


class AdaIN(torch.nn.Module):
    """Adaptive Instance Normalization (AdaIN) layer.

    Normalizes the content features per sample and per channel, then
    re-scales and re-shifts them with the statistics of the style features:

        AdaIN(x, y) = std(y) * (x - mean(x)) / std(x) + mean(y)
    """

    def __init__(self, num_features=3):
        """Create the layer.

        Args:
            num_features: channel count handed to ``torch.nn.InstanceNorm2d``.
                Defaults to 3. With ``affine=False`` (the InstanceNorm2d
                default used here) the normalization itself does not depend
                on this value, but recent PyTorch versions emit a warning when
                it differs from the channel count of ``x``. Pass the real
                channel count of the content features to avoid that.
        """
        super(AdaIN, self).__init__()
        # Instance normalization computes (x - mean) / std per sample and per
        # channel over the spatial dimensions (a z-score). It supplies the
        # content-normalization half of AdaIN.
        self.instance_norm = torch.nn.InstanceNorm2d(num_features)

    def forward(self, x, y):
        """Re-style content tensor ``x`` with the statistics of style tensor ``y``.

        Args:
            x: content tensor, normalized by ``self.instance_norm``. It must be
                a shape InstanceNorm2d accepts: (N, C, H, W) or unbatched (C, H, W).
            y: style tensor; its mean and std come from
                ``mean_and_std_of_image`` (imported from ``utils``).

        Returns:
            A tensor with the same shape as ``x``:
            ``y_std * instance_norm(x) + y_mean``.
        """
        # The content statistics are computed inside instance_norm, so only
        # the style statistics are needed explicitly.
        y_mean, y_std = mean_and_std_of_image(y)

        x_norm = self.instance_norm(x)

        # Broadcast the style statistics over the normalized content.
        # NOTE(review): presumably mean_and_std_of_image returns per-channel
        # stats shaped like (N, C, 1, 1). ``expand`` only stretches size-1
        # dims, so they must be expandable to x's shape. Confirm against utils.
        x_size = x_norm.size()
        return y_std.expand(x_size) * x_norm + y_mean.expand(x_size)