import torch
import torch.nn as nn
from torch.nn.utils import spectral_norm
import torch.nn.functional as F

import random

seq = nn.Sequential

def weights_init(m):
    classname = m.__class__.__name__
    if classname.find('Conv') != -1:
        try:
            m.weight.data.normal_(0.0, 0.02)
        except:
            # some conv modules may not expose a writable .weight here; skip them
            pass
    elif classname.find('BatchNorm') != -1:
        m.weight.data.normal_(1.0, 0.02)
        m.bias.data.fill_(0)

# Layer constructors: conv / deconv / linear layers are wrapped in spectral norm.
def conv2d(*args, **kwargs):
    return spectral_norm(nn.Conv2d(*args, **kwargs))


def convTranspose2d(*args, **kwargs):
    return spectral_norm(nn.ConvTranspose2d(*args, **kwargs))


def batchNorm2d(*args, **kwargs):
    return nn.BatchNorm2d(*args, **kwargs)


def linear(*args, **kwargs):
    return spectral_norm(nn.Linear(*args, **kwargs))

class PixelNorm(nn.Module):
    def forward(self, input):
        return input * torch.rsqrt(torch.mean(input ** 2, dim=1, keepdim=True) + 1e-8)

class Reshape(nn.Module):
    def __init__(self, shape):
        super().__init__()
        self.target_shape = shape

    def forward(self, feat):
        batch = feat.shape[0]
        return feat.view(batch, *self.target_shape)

class GLU(nn.Module):
    def forward(self, x):
        nc = x.size(1)
        assert nc % 2 == 0, "channels don't divide 2!"
        nc = nc // 2
        return x[:, :nc] * torch.sigmoid(x[:, nc:])

class NoiseInjection(nn.Module):
    def __init__(self):
        super().__init__()
        # learned per-module scale for the injected noise, initialised to zero
        self.weight = nn.Parameter(torch.zeros(1), requires_grad=True)

    def forward(self, feat, noise=None):
        if noise is None:
            batch, _, height, width = feat.shape
            noise = torch.randn(batch, 1, height, width).to(feat.device)

        return feat + self.weight * noise

class Swish(nn.Module):
    def forward(self, feat):
        return feat * torch.sigmoid(feat)

class SEBlock(nn.Module):
    """Squeeze-and-excitation style gate: channel weights computed from a
    low-resolution feature map modulate a high-resolution one."""

    def __init__(self, ch_in, ch_out):
        super().__init__()

        self.main = nn.Sequential(
            nn.AdaptiveAvgPool2d(4),
            conv2d(ch_in, ch_out, 4, 1, 0, bias=False), Swish(),
            conv2d(ch_out, ch_out, 1, 1, 0, bias=False), nn.Sigmoid())

    def forward(self, feat_small, feat_big):
        return feat_big * self.main(feat_small)

class InitLayer(nn.Module):
    def __init__(self, nz, channel):
        super().__init__()

        self.init = nn.Sequential(
            convTranspose2d(nz, channel*2, 4, 1, 0, bias=False),
            batchNorm2d(channel*2), GLU())

    def forward(self, noise):
        noise = noise.view(noise.shape[0], -1, 1, 1)
        return self.init(noise)

def UpBlock(in_planes, out_planes):
    block = nn.Sequential(
        nn.Upsample(scale_factor=2, mode='nearest'),
        conv2d(in_planes, out_planes*2, 3, 1, 1, bias=False),
        batchNorm2d(out_planes*2), GLU())
    return block

def UpBlockComp(in_planes, out_planes):
    block = nn.Sequential(
        nn.Upsample(scale_factor=2, mode='nearest'),
        conv2d(in_planes, out_planes*2, 3, 1, 1, bias=False),
        NoiseInjection(),
        batchNorm2d(out_planes*2), GLU(),
        conv2d(out_planes, out_planes*2, 3, 1, 1, bias=False),
        NoiseInjection(),
        batchNorm2d(out_planes*2), GLU())
    return block

class Generator(nn.Module):
    def __init__(self, ngf=64, nz=100, nc=3, im_size=1024):
        super(Generator, self).__init__()

        # channel count at each spatial resolution
        nfc_multi = {4:16, 8:8, 16:4, 32:2, 64:2, 128:1, 256:0.5, 512:0.25, 1024:0.125}
        nfc = {}
        for k, v in nfc_multi.items():
            nfc[k] = int(v*ngf)

        self.im_size = im_size

        self.init = InitLayer(nz, channel=nfc[4])

        self.feat_8 = UpBlockComp(nfc[4], nfc[8])
        self.feat_16 = UpBlock(nfc[8], nfc[16])
        self.feat_32 = UpBlockComp(nfc[16], nfc[32])
        self.feat_64 = UpBlock(nfc[32], nfc[64])
        self.feat_128 = UpBlockComp(nfc[64], nfc[128])
        self.feat_256 = UpBlock(nfc[128], nfc[256])

        # skip-layer excitation connections between low- and high-resolution features
        self.se_64 = SEBlock(nfc[4], nfc[64])
        self.se_128 = SEBlock(nfc[8], nfc[128])
        self.se_256 = SEBlock(nfc[16], nfc[256])

        self.to_128 = conv2d(nfc[128], nc, 1, 1, 0, bias=False)
        self.to_big = conv2d(nfc[im_size], nc, 3, 1, 1, bias=False)

        if im_size > 256:
            self.feat_512 = UpBlockComp(nfc[256], nfc[512])
            self.se_512 = SEBlock(nfc[32], nfc[512])
        if im_size > 512:
            self.feat_1024 = UpBlock(nfc[512], nfc[1024])

    def forward(self, input):

        feat_4 = self.init(input)
        feat_8 = self.feat_8(feat_4)
        feat_16 = self.feat_16(feat_8)
        feat_32 = self.feat_32(feat_16)

        feat_64 = self.se_64(feat_4, self.feat_64(feat_32))

        feat_128 = self.se_128(feat_8, self.feat_128(feat_64))

        feat_256 = self.se_256(feat_16, self.feat_256(feat_128))

        if self.im_size == 256:
            # note: unlike the 1024 branch below, these outputs are not passed through tanh
            return [self.to_big(feat_256), self.to_128(feat_128)]

        feat_512 = self.se_512(feat_32, self.feat_512(feat_256))
        if self.im_size == 512:
            return [self.to_big(feat_512), self.to_128(feat_128)]

        feat_1024 = self.feat_1024(feat_512)

        im_128 = torch.tanh(self.to_128(feat_128))
        im_1024 = torch.tanh(self.to_big(feat_1024))

        return [im_1024, im_128]

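# Shape sketch for the generator above (illustrative, using the default
# ngf=64, nz=100, nc=3, im_size=1024): a [B, 100] latent is reshaped to
# [B, 100, 1, 1], expanded by InitLayer to [B, 1024, 4, 4], doubled in
# resolution by each Up block, and mapped by to_big to a [B, 3, 1024, 1024]
# image; to_128 taps the 128x128 feature map for the auxiliary
# [B, 3, 128, 128] output.
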
class DownBlock(nn.Module):
    def __init__(self, in_planes, out_planes):
        super(DownBlock, self).__init__()

        self.main = nn.Sequential(
            conv2d(in_planes, out_planes, 4, 2, 1, bias=False),
            batchNorm2d(out_planes), nn.LeakyReLU(0.2, inplace=True))

    def forward(self, feat):
        return self.main(feat)

class DownBlockComp(nn.Module):
    def __init__(self, in_planes, out_planes):
        super(DownBlockComp, self).__init__()

        self.main = nn.Sequential(
            conv2d(in_planes, out_planes, 4, 2, 1, bias=False),
            batchNorm2d(out_planes), nn.LeakyReLU(0.2, inplace=True),
            conv2d(out_planes, out_planes, 3, 1, 1, bias=False),
            batchNorm2d(out_planes), nn.LeakyReLU(0.2))

        # shortcut path: average pooling plus a 1x1 projection
        self.direct = nn.Sequential(
            nn.AvgPool2d(2, 2),
            conv2d(in_planes, out_planes, 1, 1, 0, bias=False),
            batchNorm2d(out_planes), nn.LeakyReLU(0.2))

    def forward(self, feat):
        # average of the convolutional path and the shortcut path
        return (self.main(feat) + self.direct(feat)) / 2

class Discriminator(nn.Module):
    def __init__(self, ndf=64, nc=3, im_size=512):
        super(Discriminator, self).__init__()
        self.ndf = ndf
        self.im_size = im_size

        nfc_multi = {4:16, 8:16, 16:8, 32:4, 64:2, 128:1, 256:0.5, 512:0.25, 1024:0.125}
        nfc = {}
        for k, v in nfc_multi.items():
            nfc[k] = int(v*ndf)

        # stem for the full-resolution input; the amount of downsampling depends on im_size
        if im_size == 1024:
            self.down_from_big = nn.Sequential(
                conv2d(nc, nfc[1024], 4, 2, 1, bias=False),
                nn.LeakyReLU(0.2, inplace=True),
                conv2d(nfc[1024], nfc[512], 4, 2, 1, bias=False),
                batchNorm2d(nfc[512]),
                nn.LeakyReLU(0.2, inplace=True))
        elif im_size == 512:
            self.down_from_big = nn.Sequential(
                conv2d(nc, nfc[512], 4, 2, 1, bias=False),
                nn.LeakyReLU(0.2, inplace=True))
        elif im_size == 256:
            self.down_from_big = nn.Sequential(
                conv2d(nc, nfc[512], 3, 1, 1, bias=False),
                nn.LeakyReLU(0.2, inplace=True))

        self.down_4 = DownBlockComp(nfc[512], nfc[256])
        self.down_8 = DownBlockComp(nfc[256], nfc[128])
        self.down_16 = DownBlockComp(nfc[128], nfc[64])
        self.down_32 = DownBlockComp(nfc[64], nfc[32])
        self.down_64 = DownBlockComp(nfc[32], nfc[16])

        self.rf_big = nn.Sequential(
            conv2d(nfc[16], nfc[8], 1, 1, 0, bias=False),
            batchNorm2d(nfc[8]), nn.LeakyReLU(0.2, inplace=True),
            conv2d(nfc[8], 1, 4, 1, 0, bias=False))

        self.se_2_16 = SEBlock(nfc[512], nfc[64])
        self.se_4_32 = SEBlock(nfc[256], nfc[32])
        self.se_8_64 = SEBlock(nfc[128], nfc[16])

        # separate branch for the 128x128 input
        self.down_from_small = nn.Sequential(
            conv2d(nc, nfc[256], 4, 2, 1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
            DownBlock(nfc[256], nfc[128]),
            DownBlock(nfc[128], nfc[64]),
            DownBlock(nfc[64], nfc[32]))

        self.rf_small = conv2d(nfc[32], 1, 4, 1, 0, bias=False)

        # decoders used for the reconstruction targets on real images
        self.decoder_big = SimpleDecoder(nfc[16], nc)
        self.decoder_part = SimpleDecoder(nfc[32], nc)
        self.decoder_small = SimpleDecoder(nfc[32], nc)

    def forward(self, imgs, label, part=None):
        # a single tensor is turned into the expected [full-resolution, 128x128] pair
        if not isinstance(imgs, list):
            imgs = [F.interpolate(imgs, size=self.im_size), F.interpolate(imgs, size=128)]

        feat_2 = self.down_from_big(imgs[0])
        feat_4 = self.down_4(feat_2)
        feat_8 = self.down_8(feat_4)

        feat_16 = self.down_16(feat_8)
        feat_16 = self.se_2_16(feat_2, feat_16)

        feat_32 = self.down_32(feat_16)
        feat_32 = self.se_4_32(feat_4, feat_32)

        feat_last = self.down_64(feat_32)
        feat_last = self.se_8_64(feat_8, feat_last)

        rf_0 = self.rf_big(feat_last).view(-1)

        feat_small = self.down_from_small(imgs[1])
        rf_1 = self.rf_small(feat_small).view(-1)

        if label == 'real':
            # on real images, also return reconstructions for the decoder loss
            rec_img_big = self.decoder_big(feat_last)
            rec_img_small = self.decoder_small(feat_small)

            # `part` selects which quadrant of feat_32 the part decoder reconstructs
            assert part is not None
            rec_img_part = None
            if part == 0:
                rec_img_part = self.decoder_part(feat_32[:, :, :8, :8])
            if part == 1:
                rec_img_part = self.decoder_part(feat_32[:, :, :8, 8:])
            if part == 2:
                rec_img_part = self.decoder_part(feat_32[:, :, 8:, :8])
            if part == 3:
                rec_img_part = self.decoder_part(feat_32[:, :, 8:, 8:])

            return torch.cat([rf_0, rf_1]), [rec_img_big, rec_img_small, rec_img_part]

        return torch.cat([rf_0, rf_1])

class SimpleDecoder(nn.Module):
    """Small decoder that reconstructs a 128x128 image from a feature map."""

    def __init__(self, nfc_in=64, nc=3):
        super(SimpleDecoder, self).__init__()

        nfc_multi = {4:16, 8:8, 16:4, 32:2, 64:2, 128:1, 256:0.5, 512:0.25, 1024:0.125}
        nfc = {}
        for k, v in nfc_multi.items():
            nfc[k] = int(v*32)

        def upBlock(in_planes, out_planes):
            block = nn.Sequential(
                nn.Upsample(scale_factor=2, mode='nearest'),
                conv2d(in_planes, out_planes*2, 3, 1, 1, bias=False),
                batchNorm2d(out_planes*2), GLU())
            return block

        self.main = nn.Sequential(
            nn.AdaptiveAvgPool2d(8),
            upBlock(nfc_in, nfc[16]),
            upBlock(nfc[16], nfc[32]),
            upBlock(nfc[32], nfc[64]),
            upBlock(nfc[64], nfc[128]),
            conv2d(nfc[128], nc, 3, 1, 1, bias=False),
            nn.Tanh())

    def forward(self, input):
        return self.main(input)

from random import randint


def random_crop(image, size):
    h, w = image.shape[2:]
    ch = randint(0, h - size - 1)
    cw = randint(0, w - size - 1)
    return image[:, :, ch:ch + size, cw:cw + size]

class TextureDiscriminator(nn.Module):
    def __init__(self, ndf=64, nc=3, im_size=512):
        super(TextureDiscriminator, self).__init__()
        self.ndf = ndf
        self.im_size = im_size

        nfc_multi = {4:16, 8:8, 16:8, 32:4, 64:2, 128:1, 256:0.5, 512:0.25, 1024:0.125}
        nfc = {}
        for k, v in nfc_multi.items():
            nfc[k] = int(v*ndf)

        self.down_from_small = nn.Sequential(
            conv2d(nc, nfc[256], 4, 2, 1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
            DownBlock(nfc[256], nfc[128]),
            DownBlock(nfc[128], nfc[64]),
            DownBlock(nfc[64], nfc[32]))
        # input channels must match the nfc[32] output of down_from_small
        self.rf_small = nn.Sequential(
            conv2d(nfc[32], 1, 4, 1, 0, bias=False))

        self.decoder_small = SimpleDecoder(nfc[32], nc)

    def forward(self, img, label):
        img = random_crop(img, size=128)

        feat_small = self.down_from_small(img)
        rf = self.rf_small(feat_small).view(-1)

        if label == 'real':
            rec_img_small = self.decoder_small(feat_small)

            return rf, rec_img_small, img

        return rf
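

# Minimal smoke test, included as a hedged usage sketch rather than part of the
# training pipeline: build a 256x256 generator/discriminator pair and run one
# forward pass of each. Batch size, latent size, and image size are illustrative.
if __name__ == '__main__':
    nz = 100
    netG = Generator(ngf=64, nz=nz, nc=3, im_size=256)
    netD = Discriminator(ndf=64, nc=3, im_size=256)

    noise = torch.randn(2, nz)
    fake_images = netG(noise)  # [B, 3, 256, 256] and [B, 3, 128, 128]
    print([im.shape for im in fake_images])

    # fake pass: a flat tensor of real/fake logits from both scales
    d_fake = netD(fake_images, label='fake')
    print(d_fake.shape)

    # real pass: logits plus [big, small, part] reconstructions for the decoder loss
    real_big = torch.randn(2, 3, 256, 256)
    real_small = F.interpolate(real_big, size=128)
    d_real, rec_imgs = netD([real_big, real_small], label='real', part=random.randint(0, 3))
    print(d_real.shape, [im.shape for im in rec_imgs])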