Spaces:
Running
on
Zero
Running
on
Zero
import torch | |
import torch.nn as nn | |
import torch.nn.utils.spectral_norm as SpectralNorm | |
def GroupNorm(in_channels, channels_per_group=32):
    """Build a ``torch.nn.GroupNorm`` with a fixed number of channels per group.

    Args:
        in_channels: Number of channels of the feature map to normalize. Must be
            a positive multiple of ``channels_per_group``.
        channels_per_group: Channels assigned to each group. Defaults to 32,
            the value this helper has always used.

    Returns:
        A ``torch.nn.GroupNorm`` with ``in_channels // channels_per_group``
        groups, ``eps=1e-6`` and learnable affine parameters.

    Raises:
        AssertionError: if ``in_channels`` is not a positive multiple of
            ``channels_per_group``.
    """
    ec = channels_per_group
    # Reject 0 explicitly: it would yield zero groups and fail later inside torch.
    assert in_channels > 0 and in_channels % ec == 0, (
        f'in_channels ({in_channels}) must be a positive multiple of {ec}'
    )
    return torch.nn.GroupNorm(num_groups=in_channels // ec, num_channels=in_channels, eps=1e-6, affine=True)
def swish(x):
    """Apply the swish activation elementwise: ``x * sigmoid(x)``."""
    gate = torch.sigmoid(x)
    return x * gate
class ResTextBlockV2(nn.Module):
    """Residual block: two (GroupNorm -> swish -> 3x3 conv) stages plus a skip path.

    Both 3x3 convolutions are wrapped with ``SpectralNorm``. When the input and
    output channel counts differ, the skip path is projected with a 1x1
    convolution (``conv_out``); otherwise the input is added back unchanged.

    Args:
        in_channels: Channels of the input feature map (passed to ``GroupNorm``).
        out_channels: Channels of the output; defaults to ``in_channels``.
    """

    def __init__(self, in_channels, out_channels=None):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = in_channels if out_channels is None else out_channels
        # Build every layer from the resolved channel count: the raw argument is
        # None when omitted, which would make nn.Conv2d / GroupNorm fail.
        out_channels = self.out_channels
        self.norm1 = GroupNorm(in_channels)
        self.conv1 = SpectralNorm(nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1))
        self.norm2 = GroupNorm(out_channels)
        self.conv2 = SpectralNorm(nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1))
        if self.in_channels != self.out_channels:
            # 1x1 projection so the skip connection matches the residual's channels.
            self.conv_out = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)

    def forward(self, x_in):
        """Return ``residual(x_in) + skip(x_in)``.

        Spatial size is preserved (3x3 kernels, stride 1, padding 1); the channel
        count becomes ``self.out_channels``.
        """
        x = x_in
        x = self.norm1(x)
        x = swish(x)
        x = self.conv1(x)
        x = self.norm2(x)
        x = swish(x)
        x = self.conv2(x)
        if self.in_channels != self.out_channels:
            x_in = self.conv_out(x_in)
        return x + x_in
def calc_mean_std_4D(feat, eps=1e-6):
    """Compute per-sample, per-channel mean and standard deviation of a 4D tensor.

    Statistics are taken over all positions after the first two dimensions of
    each ``(batch, channel)`` slice. The variance is ``Tensor.var``'s default
    (unbiased) estimate, and ``eps`` is added to it before the square root.

    Args:
        feat: 4D tensor, e.g. ``(B, C, H, W)``; contiguity is not required.
        eps: Small value added to the variance for numerical stability.

    Returns:
        ``(mean, std)``, each of shape ``(B, C, 1, 1)`` so they broadcast
        against ``feat``.

    Raises:
        AssertionError: if ``feat`` is not 4D.
    """
    size = feat.size()
    assert len(size) == 4, 'The input feature should be 4D tensor.'
    b, c = size[:2]
    # reshape (not view) so non-contiguous inputs, e.g. transposed tensors, work;
    # flatten once and reuse for both statistics.
    flat = feat.reshape(b, c, -1)
    feat_var = flat.var(dim=2) + eps
    feat_std = feat_var.sqrt().view(b, c, 1, 1)
    feat_mean = flat.mean(dim=2).view(b, c, 1, 1)
    return feat_mean, feat_std
def adaptive_instance_normalization(prior_feat, lq_feat):
    """Re-style ``prior_feat`` with the per-channel statistics of ``lq_feat``.

    ``prior_feat`` is normalized per (sample, channel) with its own mean/std
    from ``calc_mean_std_4D``. It is then scaled by ``lq_feat``'s std and
    shifted by ``lq_feat``'s mean. The ``lq_feat`` statistics are expanded to
    ``prior_feat``'s shape, so their batch/channel sizes must be expandable to
    those of ``prior_feat``.
    """
    target_mean, target_std = calc_mean_std_4D(lq_feat)
    source_mean, source_std = calc_mean_std_4D(prior_feat)
    centered = (prior_feat - source_mean.expand_as(prior_feat)) / source_std.expand_as(prior_feat)
    return centered * target_std.expand_as(prior_feat) + target_mean.expand_as(prior_feat)
def network_param(net):
    """Return the total number of parameters of ``net``, in millions.

    Every tensor yielded by ``net.parameters()`` is counted, whether or not it
    requires grad.
    """
    total = sum(p.numel() for p in net.parameters())
    return total / 1e6