# Provenance (scraped file-viewer header, kept as comments so the module parses):
#   author: hakansivuk
#   commit: 087921f ("Final commit")
#   size:   2.26 kB; viewer scan status: "No virus"
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import functools
from .G1 import G1
from .G2 import G2
from .G3 import G3
from .zencoder import DeeperZencoder
##########################################
class Generator(nn.Module):
    """Three-stage generator that fills the masked region at 64, 128 and 256 px.

    Stage G1 works on a 64x64 copy of ``input``. Each later stage (G2 at
    128x128, G3 at 256x256) receives the previous stage's output, upsampled 2x
    and pasted into the masked region of the ground truth at that scale. All
    three stages share the style codes produced by ``self.zencoder``.

    Mask convention (see ``masked_fake``): where ``mask == 1`` pixels come from
    the generated image; where ``mask == 0`` the reference image is kept.
    """

    def __init__(self, cfg):
        """Build the three stages and the style encoder.

        ``cfg`` is passed unchanged to ``G1``, ``G2``, ``G3`` and
        ``DeeperZencoder``; its schema is defined by those modules.
        """
        super().__init__()
        self.G1 = G1(cfg)
        self.G2 = G2(cfg)
        self.G3 = G3(cfg)
        self.zencoder = DeeperZencoder(cfg)
        # Doubles spatial size between stages ('nearest' is the other option tried).
        # align_corners=False is PyTorch's default; spelled out so older releases
        # do not emit the "default behavior has changed" warning.
        self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False)

    def forward(self, gt, input, segmap, inst_map, mask, cached_codes=None):
        """Run the three stages and return the per-scale tensors.

        Args:
            gt: reference image. It is resized bilinearly, so it must be a 4-D
                (N, C, H, W) tensor.
            input: image fed to the style encoder and, resized to 64x64, to G1.
                (Shadows the builtin name; kept for caller compatibility.)
            segmap: forwarded to the encoder and to every stage.
            inst_map: forwarded to the encoder only.
            mask: 4-D mask; 1 marks pixels taken from the generator output.
                The encoder receives ``1.0 - mask``.
            cached_codes: forwarded to the encoder unchanged.

        Returns:
            Four lists, each ordered [G1, G2, G3]:
            resized ground truth, stage inputs, composited outputs (generated
            pixels inside the mask, ground truth outside), raw stage outputs.
        """
        # Ground truth at each stage's working resolution.
        gt_G1 = F.interpolate(gt, size=(64, 64), mode='bilinear', align_corners=False)
        gt_G2 = F.interpolate(gt, size=(128, 128), mode='bilinear', align_corners=False)
        gt_G3 = F.interpolate(gt, size=(256, 256), mode='bilinear', align_corners=False)

        # The encoder is given the complement of the mask.
        style_codes = self.zencoder(input, segmap, 1.0 - mask, inst_map, cached_codes=cached_codes)

        input_G1 = F.interpolate(input, size=(64, 64), mode='bilinear', align_corners=False)
        fake_G1 = self.G1(input_G1, segmap, mask, style_codes)
        mask_fake_G1 = self.masked_fake(gt_G1, fake_G1, mask)

        # next_img upsamples by exactly 2x, so this assumes G1 returns 64x64 and
        # G2 returns 128x128 outputs -- NOTE(review): not checked here.
        input_G2 = self.next_img(gt_G2, fake_G1, mask)
        fake_G2 = self.G2(input_G2, segmap, mask, style_codes)
        mask_fake_G2 = self.masked_fake(gt_G2, fake_G2, mask)

        input_G3 = self.next_img(gt_G3, fake_G2, mask)
        fake_G3 = self.G3(input_G3, segmap, mask, style_codes)
        mask_fake_G3 = self.masked_fake(gt_G3, fake_G3, mask)

        return (
            [gt_G1, gt_G2, gt_G3],
            [input_G1, input_G2, input_G3],
            [mask_fake_G1, mask_fake_G2, mask_fake_G3],
            [fake_G1, fake_G2, fake_G3],
        )

    def get_style_codes(self, input, segmap, inst_map, mask):
        """Return style codes from the encoder, using the same inverted mask as ``forward``."""
        return self.zencoder.generate_style_codes(input, segmap, 1.0 - mask, inst_map)

    def masked_fake(self, img, fake, mask):
        """Composite ``fake`` into ``img`` wherever ``mask`` is 1.

        The mask is resized with nearest-neighbour sampling to ``fake``'s
        spatial size, so it stays binary if it was binary.
        Returns ``mask * fake + (1 - mask) * img``.
        """
        mask = F.interpolate(mask, size=fake.size()[2:], mode='nearest')
        return mask * fake + (1. - mask) * img

    def next_img(self, img, prev_fake, mask):
        """Build the next stage's input.

        Upsamples ``prev_fake`` with ``self.up`` (2x) and composites it into
        ``img`` inside the mask.
        """
        return self.masked_fake(img, self.up(prev_fake), mask)