|
import torch |
|
import torch.nn as nn |
|
import torch.nn.functional as F |
|
|
|
class AdaptiveInstanceNorm1d(nn.Module):
    """Adaptive instance normalisation whose affine terms are set externally.

    ``weight`` and ``bias`` are not learned by this module. They start as
    ``None`` and must be assigned (one value per ``(sample, channel)`` pair,
    i.e. ``batch * num_features`` entries) before ``forward`` is called.

    Args:
        num_features: Number of channels; sizes the running-stat buffers.
        eps: Value added to the variance for numerical stability.
        momentum: Momentum forwarded to ``F.batch_norm``.

    Note:
        ``running_mean`` / ``running_var`` are registered buffers, but
        ``forward`` hands ``F.batch_norm`` per-call repeated copies of them,
        so the in-place running-stat update never reaches the buffers.
    """

    def __init__(self, num_features, eps=1e-5, momentum=0.1):
        super().__init__()
        self.num_features = num_features
        self.eps = eps
        self.momentum = momentum
        # Plain attributes (not Parameters): filled in from outside the
        # module before each forward pass.
        self.weight = None
        self.bias = None
        self.register_buffer('running_mean', torch.zeros(num_features))
        self.register_buffer('running_var', torch.ones(num_features))

    def forward(self, x, direct_weighting=False, no_std=False):
        """Apply the assigned scale/shift to ``x``.

        Args:
            x: 3-D tensor. It is permuted with ``(1, 2, 0)`` and the first two
                axes of the result are treated as batch and channel, so the
                expected layout is ``(length, batch, channel)``.
            direct_weighting: If true, skip normalisation and compute
                ``x * weight + bias`` per ``(sample, channel)`` pair.
            no_std: Only used with ``direct_weighting``; if true the scale is
                skipped and only ``bias`` is added.

        Returns:
            Tensor with the same shape/layout as the input ``x``.

        Raises:
            AssertionError: If ``weight`` or ``bias`` has not been assigned.
        """
        # Explicit raise instead of ``assert`` so the guard survives
        # ``python -O``; the exception type and message are unchanged.
        if self.weight is None or self.bias is None:
            raise AssertionError("Please assign AdaIN weight first")

        x = x.permute(1, 2, 0)

        b, c = x.size(0), x.size(1)
        trailing = x.size()[2:]

        if direct_weighting:
            # Put the flattened (batch * channel) axis last so a weight/bias
            # of length b * c broadcasts over every trailing position. A plain
            # ``view(b * c)`` only works when the trailing length is 1.
            flat = x.reshape(b * c, *trailing).transpose(0, 1)
            if no_std:
                out = flat + self.bias
            else:
                out = flat.mul(self.weight) + self.bias
            out = out.transpose(0, 1).reshape(b, c, *trailing)
        else:
            # Only the normalising path needs the tiled statistics.
            running_mean = self.running_mean.repeat(b)
            running_var = self.running_var.repeat(b)

            # Fold batch into channels: batch norm over a batch of one with
            # b * c channels normalises each (sample, channel) independently.
            x_reshaped = x.contiguous().view(1, b * c, *trailing)
            out = F.batch_norm(
                x_reshaped, running_mean, running_var, self.weight, self.bias,
                True, self.momentum, self.eps)
            out = out.view(b, c, *trailing)

        # Undo the initial permute.
        out = out.permute(2, 0, 1)
        return out

    def __repr__(self):
        return f"{self.__class__.__name__}({self.num_features})"
|
|
|
def assign_adain_params(adain_params, model):
    """Distribute columns of ``adain_params`` to the AdaIN layers of ``model``.

    Modules are visited in ``model.modules()`` order. Each module whose class
    name is ``"AdaptiveInstanceNorm1d"`` receives the leading
    ``num_features`` columns as its ``bias`` and the next ``num_features``
    columns as its ``weight``, both flattened to 1-D.

    Args:
        adain_params: 2-D tensor; columns are consumed from the left.
        model: Module tree to search for AdaIN layers.
    """
    remaining = adain_params
    for layer in model.modules():
        if layer.__class__.__name__ != "AdaptiveInstanceNorm1d":
            continue

        width = layer.num_features
        shift = remaining[:, :width]
        scale = remaining[:, width:2 * width]
        layer.bias = shift.contiguous().view(-1)
        layer.weight = scale.contiguous().view(-1)

        # Advance only when columns are left over; on an exact fit the
        # tensor is kept, so a later layer would see the same columns.
        if remaining.size(1) > 2 * width:
            remaining = remaining[:, 2 * width:]
|
|
|
|
|
def get_num_adain_params(model):
    """Return the total number of AdaIN parameters ``model`` needs.

    Every module whose class name is ``"AdaptiveInstanceNorm1d"`` contributes
    ``2 * num_features`` (one bias and one weight value per feature).
    """
    return sum(
        2 * layer.num_features
        for layer in model.modules()
        if layer.__class__.__name__ == "AdaptiveInstanceNorm1d"
    )
|
|