File size: 2,020 Bytes
9667e74
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
import collections
from torch import nn
from NTED.base_module import Encoder, Decoder

from torch.cuda.amp import autocast as autocast

class Generator(nn.Module):
    """Image generator built from two encoders and one decoder.

    The module owns three sub-modules:

    * ``reference_encoder`` -- an ``Encoder`` built with 3 input channels
      (presumably RGB -- confirm against ``Encoder``) that is run on the
      source image.
    * ``skeleton_encoder`` -- an ``Encoder`` built with ``semantic_dim``
      input channels that is run on the skeleton input.
    * ``target_image_renderer`` -- a ``Decoder`` that produces the output
      image from the skeleton features and the recorded neural textures.

    Args:
        size: Forwarded as the first argument to every sub-module.
        semantic_dim: Input channel argument of ``skeleton_encoder``.
        channels: Forwarded to every sub-module.
        num_labels: Forwarded to ``reference_encoder`` and the decoder.
        match_kernels: Forwarded to ``reference_encoder`` and the decoder.
        blur_kernel: Forwarded to ``reference_encoder`` and the decoder.
            ``None`` (the default) selects ``[1, 3, 3, 1]``.
    """

    def __init__(
        self,
        size,
        semantic_dim,
        channels,
        num_labels,
        match_kernels,
        blur_kernel=None,
    ):
        super().__init__()
        # Build the default per instance: a list literal in the signature
        # would be one object shared by every Generator and its sub-modules.
        if blur_kernel is None:
            blur_kernel = [1, 3, 3, 1]

        self.size = size
        self.reference_encoder = Encoder(
            size, 3, channels, num_labels, match_kernels, blur_kernel
        )

        # Only the first three arguments are given here; the remaining
        # Encoder parameters keep whatever defaults Encoder declares.
        self.skeleton_encoder = Encoder(
            size, semantic_dim, channels,
        )

        self.target_image_renderer = Decoder(
            size, channels, num_labels, match_kernels, blur_kernel
        )

    def _cal_temp(self, module):
        """Return the number of trainable parameter elements in ``module``."""
        return sum(p.numel() for p in module.parameters() if p.requires_grad)

    def _generate(self, source_image, skeleton):
        """Run both encoders and the decoder; shared by both AMP modes."""
        # Missing keys start as empty lists; the sub-modules receive this
        # mapping and are expected to fill it in as they run.
        recorder = collections.defaultdict(list)
        skeleton_feature = self.skeleton_encoder(skeleton)
        # The encoder's return value is unused: what matters is what it
        # stored in ``recorder`` under "neural_textures".
        self.reference_encoder(source_image, recorder)
        neural_textures = recorder["neural_textures"]
        fake_image = self.target_image_renderer(
            skeleton_feature, neural_textures, recorder
        )
        return {"fake_image": fake_image, "info": recorder}

    def forward(
        self,
        source_image,
        skeleton,
        amp_flag=False,
    ):
        """Generate an image from a source image and a skeleton input.

        Args:
            source_image: Input passed to ``reference_encoder``.
            skeleton: Input passed to ``skeleton_encoder``.
            amp_flag: When true, the whole pass runs inside
                ``torch.cuda.amp.autocast()``.

        Returns:
            dict with two entries: ``"fake_image"``, the decoder output, and
            ``"info"``, the ``defaultdict(list)`` that was handed to the
            reference encoder and the decoder during the pass.
        """
        if amp_flag:
            with autocast():
                return self._generate(source_image, skeleton)
        # No autocast(enabled=False) here on purpose: that would switch off
        # an autocast context the caller may already have opened.
        return self._generate(source_image, skeleton)