Spaces:
Running
on
Zero
Running
on
Zero
File size: 2,163 Bytes
981b0ab |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 |
import torch
import torch.nn as nn
import torch.nn.utils.spectral_norm as SpectralNorm
def GroupNorm(in_channels):
    """Build a ``torch.nn.GroupNorm`` layer that puts 32 channels in each group.

    Args:
        in_channels: Number of channels to normalize. Must be a multiple
            of 32 (checked with an ``assert``).

    Returns:
        A ``torch.nn.GroupNorm`` with ``in_channels // 32`` groups,
        ``eps=1e-6`` and learnable affine parameters.
    """
    channels_per_group = 32
    assert in_channels % channels_per_group == 0
    group_count = in_channels // channels_per_group
    return torch.nn.GroupNorm(
        num_groups=group_count,
        num_channels=in_channels,
        eps=1e-6,
        affine=True,
    )
def swish(x):
    """Apply the swish activation, ``x * sigmoid(x)``, element-wise.

    Args:
        x: Input tensor.

    Returns:
        A tensor with the same shape as ``x``.
    """
    gate = torch.sigmoid(x)
    return x * gate
class ResTextBlockV2(nn.Module):
    """Residual block: two (GroupNorm -> swish -> spectral-norm 3x3 conv) stages plus a skip.

    Both 3x3 convolutions use stride 1 and padding 1, so the spatial size of
    the input is preserved. When the input and output channel counts differ,
    the skip path goes through a 1x1 convolution so it can be added to the
    main branch.

    Args:
        in_channels: Channel count of the input. It is passed to the
            file-level ``GroupNorm`` helper, which asserts that it is a
            multiple of 32.
        out_channels: Channel count of the output. ``None`` (the default)
            means "same as ``in_channels``". It is also passed to the
            ``GroupNorm`` helper, so the same multiple-of-32 constraint applies.
    """

    def __init__(self, in_channels, out_channels=None):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = in_channels if out_channels is None else out_channels

        # Always build layers from the resolved ``self.out_channels``. Using the
        # raw ``out_channels`` argument would pass ``None`` to GroupNorm/Conv2d
        # whenever the default is used and fail at construction time.
        self.norm1 = GroupNorm(self.in_channels)
        self.conv1 = SpectralNorm(
            nn.Conv2d(self.in_channels, self.out_channels, kernel_size=3, stride=1, padding=1)
        )
        self.norm2 = GroupNorm(self.out_channels)
        self.conv2 = SpectralNorm(
            nn.Conv2d(self.out_channels, self.out_channels, kernel_size=3, stride=1, padding=1)
        )
        # A 1x1 projection on the skip path is only needed when the channel
        # count changes; otherwise the input is added back unchanged.
        if self.in_channels != self.out_channels:
            self.conv_out = nn.Conv2d(
                self.in_channels, self.out_channels, kernel_size=1, stride=1, padding=0
            )

    def forward(self, x_in):
        """Run the block.

        Args:
            x_in: Input tensor fed to GroupNorm and Conv2d; its channel
                dimension must equal ``in_channels``.

        Returns:
            The main-branch output added to the (possibly projected) input.
        """
        x = x_in
        x = self.norm1(x)
        x = swish(x)
        x = self.conv1(x)
        x = self.norm2(x)
        x = swish(x)
        x = self.conv2(x)
        if self.in_channels != self.out_channels:
            x_in = self.conv_out(x_in)
        return x + x_in
def calc_mean_std_4D(feat, eps=1e-6):
    """Compute per-sample, per-channel mean and standard deviation of a 4D tensor.

    The last two dimensions are flattened and the statistics are taken over
    that flattened axis. ``eps`` is added to the variance before the square
    root.

    Args:
        feat: 4D tensor; the first two dimensions are kept, the rest reduced.
        eps: Value added to the variance before taking the square root.

    Returns:
        A ``(mean, std)`` tuple, each of shape ``(b, c, 1, 1)`` where
        ``b, c`` are the first two dimensions of ``feat``.

    Raises:
        AssertionError: If ``feat`` is not 4-dimensional.
    """
    shape = feat.size()
    assert len(shape) == 4, 'The input feature should be 4D tensor.'
    batch, channels = shape[0], shape[1]
    # Flatten the trailing dimensions once and reuse for both statistics.
    flat = feat.view(batch, channels, -1)
    variance = flat.var(dim=2) + eps
    std = variance.sqrt().view(batch, channels, 1, 1)
    mean = flat.mean(dim=2).view(batch, channels, 1, 1)
    return mean, std
def adaptive_instance_normalization(prior_feat, lq_feat):
    """Re-normalize ``prior_feat`` so its channel statistics match ``lq_feat``.

    ``prior_feat`` is standardized with its own per-channel mean and standard
    deviation (from ``calc_mean_std_4D``), then scaled by the standard
    deviation of ``lq_feat`` and shifted by its mean.

    Args:
        prior_feat: 4D tensor to be re-normalized; defines the output shape.
        lq_feat: 4D tensor providing the target mean and standard deviation.

    Returns:
        A tensor with the same shape as ``prior_feat``.
    """
    target_shape = prior_feat.size()
    target_mean, target_std = calc_mean_std_4D(lq_feat)
    source_mean, source_std = calc_mean_std_4D(prior_feat)

    centered = prior_feat - source_mean.expand(target_shape)
    standardized = centered / source_std.expand(target_shape)
    rescaled = standardized * target_std.expand(target_shape)
    return rescaled + target_mean.expand(target_shape)
def network_param(net):
    """Return the number of parameters of ``net`` in millions.

    Args:
        net: Object exposing ``parameters()`` that yields items with a
            ``numel()`` method (e.g. a ``torch.nn.Module``).

    Returns:
        The total element count over all parameters, divided by ``1e6``.
    """
    total_elements = sum(tensor.numel() for tensor in net.parameters())
    return total_elements / 1e6