File size: 2,263 Bytes
087921f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import functools
from .G1 import G1
from .G2 import G2
from .G3 import G3
from .zencoder import DeeperZencoder

##########################################
class Generator(nn.Module):
    """Three-stage generator driven by a shared style encoder.

    ``forward`` runs ``G1``, ``G2`` and ``G3`` in sequence against reference
    images resized to 64x64, 128x128 and 256x256. The first stage consumes a
    resized copy of ``input``; each later stage consumes the previous stage's
    output, upsampled by a factor of two and blended with the reference
    image of that stage according to ``mask``.
    """

    def __init__(self, cfg):
        """Build the three stage networks, the style encoder and the upsampler.

        Args:
            cfg: Configuration object forwarded unchanged to ``G1``, ``G2``,
                ``G3`` and ``DeeperZencoder``.
        """
        super().__init__()

        # Submodules are registered in a fixed order: G1, G2, G3, zencoder, up.
        self.G1 = G1(cfg)
        self.G2 = G2(cfg)
        self.G3 = G3(cfg)
        self.zencoder = DeeperZencoder(cfg)
        # Doubles the spatial size between stages ('nearest' is the alternative mode).
        self.up = nn.Upsample(scale_factor=2, mode='bilinear')

    def forward(self, gt, input, segmap, inst_map, mask, cached_codes=None):
        """Run all three stages and collect their per-stage tensors.

        Args:
            gt: Reference image batch; resized to each stage's resolution.
            input: Image batch given to the style encoder and, resized to
                64x64, to the first stage.
            segmap: Passed through to the style encoder and to every stage.
            inst_map: Passed through to the style encoder.
            mask: Blend weights; where it is 1 the generated pixels are kept,
                where it is 0 the reference pixels are kept. The style
                encoder receives ``1.0 - mask``.
            cached_codes: Optional value forwarded to the style encoder.

        Returns:
            A tuple of four three-element lists, ordered stage 1 to stage 3:
            the resized references, the stage inputs, the mask-blended
            outputs, and the raw stage outputs.

        NOTE(review): stage outputs are assumed to have the same spatial size
        as their inputs, so that the 2x upsampled output matches the next
        reference - confirm against G1/G2.
        """
        references = [
            F.interpolate(gt, size=(side, side), mode='bilinear')
            for side in (64, 128, 256)
        ]

        codes = self.zencoder(input, segmap, 1.0 - mask, inst_map, cached_codes=cached_codes)

        stage_inputs = []
        blended_outputs = []
        raw_outputs = []

        # The first stage starts from the down-sized input image; later
        # stages start from the previous raw output (see next_img).
        current = F.interpolate(input, size=(64, 64), mode='bilinear')
        stage_nets = (self.G1, self.G2, self.G3)
        for index, (net, reference) in enumerate(zip(stage_nets, references)):
            if index > 0:
                current = self.next_img(reference, raw_outputs[-1], mask)
            produced = net(current, segmap, mask, codes)

            stage_inputs.append(current)
            raw_outputs.append(produced)
            blended_outputs.append(self.masked_fake(reference, produced, mask))

        return references, stage_inputs, blended_outputs, raw_outputs

    def get_style_codes(self, input, segmap, inst_map, mask):
        """Return the style encoder's ``generate_style_codes`` result.

        The encoder is called with ``(input, segmap, 1.0 - mask, inst_map)``.
        """
        inverted_mask = 1.0 - mask
        return self.zencoder.generate_style_codes(input, segmap, inverted_mask, inst_map)

    @staticmethod
    def _blend(reference, generated, mask):
        """Mix ``generated`` and ``reference`` using ``mask`` as the weight.

        ``mask`` is first resized (nearest neighbour) to the spatial size of
        ``generated``, i.e. its dimensions from index 2 onwards.
        """
        weights = F.interpolate(mask, size=generated.size()[2:], mode='nearest')
        return weights * generated + (1. - weights) * reference

    def masked_fake(self, img, fake, mask):
        """Blend ``fake`` over ``img``: ``mask * fake + (1 - mask) * img``.

        ``mask`` is resized to the spatial size of ``fake`` before blending.
        """
        return self._blend(img, fake, mask)

    def next_img(self, img, prev_fake, mask):
        """Build the next stage's input from the previous stage's output.

        ``prev_fake`` is upsampled by ``self.up`` and then blended with
        ``img`` exactly as in ``masked_fake``.
        """
        enlarged = self.up(prev_fake)
        return self._blend(img, enlarged, mask)