import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import functools

from .G1 import G1
from .G2 import G2
from .G3 import G3
from .zencoder import DeeperZencoder


##########################################


class Generator(nn.Module):
    """Three-stage coarse-to-fine generator (G1 -> G2 -> G3).

    Each stage receives an image at its own resolution (64, 128 and 256
    pixels per side), the segmentation map, the mask and the style codes
    produced by the shared ``DeeperZencoder``.  The output of a stage is
    upsampled by a factor of two, composited with the ground truth of the
    next resolution, and fed to the next stage.

    NOTE(review): all image-like tensors are assumed to be 4-D
    (batch, channels, height, width) -- bilinear interpolation requires
    4-D input and the spatial size is read as ``size()[2:]``.  Confirm
    the exact channel counts against the callers.
    """

    def __init__(self, cfg):
        """Build the three sub-generators and the style encoder from ``cfg``.

        Args:
            cfg: project configuration object, forwarded unchanged to
                ``G1``, ``G2``, ``G3`` and ``DeeperZencoder``.
        """
        super().__init__()
        self.G1 = G1(cfg)
        self.G2 = G2(cfg)
        self.G3 = G3(cfg)
        self.zencoder = DeeperZencoder(cfg)
        # align_corners=False is the documented default for bilinear mode;
        # passing it explicitly keeps the numerics identical while avoiding
        # the "default upsampling behavior" UserWarning on PyTorch versions
        # that emit it.  Alternative mode: 'nearest' (which must not be
        # given align_corners).
        self.up = nn.Upsample(
            scale_factor=2, mode='bilinear', align_corners=False
        )

    @staticmethod
    def _resize_bilinear(img, side):
        """Bilinearly resize ``img`` to a ``side`` x ``side`` spatial size."""
        return F.interpolate(
            img, size=(side, side), mode='bilinear', align_corners=False
        )

    def forward(self, gt, input, segmap, inst_map, mask, cached_codes=None):
        """Run the three generator stages.

        Args:
            gt: ground-truth image tensor; resized to 64, 128 and 256 px.
            input: image tensor given to the style encoder and, resized to
                64 px, to the first stage.  (The name shadows the builtin
                but is kept because it is part of the keyword interface.)
            segmap: segmentation map, forwarded to the encoder and to every
                stage.
            inst_map: instance map, forwarded to the style encoder only.
            mask: blending mask.  Where it is 1 the generated image is
                used, where it is 0 the ground truth is kept (see
                ``masked_fake``).  The encoder receives ``1.0 - mask``.
            cached_codes: optional value forwarded as ``cached_codes`` to
                the style encoder.

        Returns:
            A 4-tuple of three-element lists, one entry per stage:
            ``[gt_G1, gt_G2, gt_G3]``,
            ``[input_G1, input_G2, input_G3]``,
            ``[mask_fake_G1, mask_fake_G2, mask_fake_G3]`` (stage outputs
            composited with the ground truth) and
            ``[fake_G1, fake_G2, fake_G3]`` (raw stage outputs).
        """
        # Ground-truth pyramid, one resolution per stage.
        gt_G1 = self._resize_bilinear(gt, 64)
        gt_G2 = self._resize_bilinear(gt, 128)
        gt_G3 = self._resize_bilinear(gt, 256)

        # Style codes are computed once and shared by all three stages.
        style_codes = self.zencoder(
            input, segmap, 1.0 - mask, inst_map, cached_codes=cached_codes
        )

        # Stage 1: works directly on the downsampled input.
        input_G1 = self._resize_bilinear(input, 64)
        fake_G1 = self.G1(input_G1, segmap, mask, style_codes)
        mask_fake_G1 = self.masked_fake(gt_G1, fake_G1, mask)

        # Stage 2: refines the upsampled stage-1 output.
        input_G2 = self.next_img(gt_G2, fake_G1, mask)
        fake_G2 = self.G2(input_G2, segmap, mask, style_codes)
        mask_fake_G2 = self.masked_fake(gt_G2, fake_G2, mask)

        # Stage 3: refines the upsampled stage-2 output.
        input_G3 = self.next_img(gt_G3, fake_G2, mask)
        fake_G3 = self.G3(input_G3, segmap, mask, style_codes)
        mask_fake_G3 = self.masked_fake(gt_G3, fake_G3, mask)

        return (
            [gt_G1, gt_G2, gt_G3],
            [input_G1, input_G2, input_G3],
            [mask_fake_G1, mask_fake_G2, mask_fake_G3],
            [fake_G1, fake_G2, fake_G3],
        )

    def get_style_codes(self, input, segmap, inst_map, mask):
        """Return the encoder's style codes without running the generators.

        Calls ``self.zencoder.generate_style_codes`` with the same argument
        order as ``forward`` uses for the encoder; the mask is inverted
        (``1.0 - mask``) before being passed on.
        """
        return self.zencoder.generate_style_codes(
            input, segmap, 1.0 - mask, inst_map
        )

    def masked_fake(self, img, fake, mask):
        """Composite ``fake`` over ``img`` using ``mask``.

        The mask is resized to the spatial size of ``fake`` with
        nearest-neighbour interpolation, then the result is
        ``mask * fake + (1 - mask) * img``.

        Args:
            img: reference image; must be broadcast-compatible with
                ``fake``.
            fake: generated image.
            mask: blending weights; 1 selects ``fake``, 0 selects ``img``.

        Returns:
            The blended tensor.
        """
        mask = F.interpolate(mask, size=fake.size()[2:], mode='nearest')
        combined = mask * fake + (1. - mask) * img
        return combined

    def next_img(self, img, prev_fake, mask):
        """Build the input of the next stage from the previous stage output.

        ``prev_fake`` is upsampled by a factor of two (bilinear) and then
        composited with ``img`` exactly as in ``masked_fake``.

        Args:
            img: ground truth at the next stage's resolution.
            prev_fake: raw output of the previous stage.
            mask: blending weights; 1 selects the upsampled output, 0
                selects ``img``.

        Returns:
            The blended tensor at the upsampled resolution.
        """
        # Same blending rule as masked_fake, so reuse it rather than
        # duplicating the mask-resize-and-mix arithmetic.
        return self.masked_fake(img, self.up(prev_fake), mask)