File size: 6,225 Bytes
a104d3f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
import torch
import torch.nn as nn
import torch.nn.functional as F

import os
import kornia
import warnings

from modules.layers.faceshifter.layers import AEI_Net
from modules.layers.faceshifter.hear_layers import Hear_Net
from third_party.arcface import iresnet100, MouthNet

def make_abs_path(fn):
    """Resolve *fn* relative to the directory that contains this source file.

    The directory is taken from the symlink-resolved location of ``__file__``,
    so the result does not depend on the current working directory. The joined
    path is returned in absolute, normalised form.
    """
    module_dir = os.path.dirname(os.path.realpath(__file__))
    return os.path.abspath(os.path.join(module_dir, fn))


class FSGenerator(nn.Module):
    """FaceShifter stage-I generator: frozen identity encoder(s) + ``AEI_Net``.

    The module owns three sub-networks:

    * ``self.iresnet`` -- an ``iresnet100`` face-ID backbone whose weights are
      loaded from ``id_ckpt``; it is put in eval mode and frozen.
    * ``self.mouth_net`` -- an optional ``MouthNet`` (frozen, eval mode) whose
      feature is concatenated to the identity vector.
    * ``self.G`` -- the ``AEI_Net`` generator, conditioned on a vector of size
      ``id_dim + mouth feature dim``.

    Args:
        id_ckpt: Path to the face-ID backbone state dict. Required in
            practice: ``None`` raises ``FileNotFoundError``.
        id_dim: Dimension of the identity embedding fed to ``AEI_Net``.
        mouth_net_param: Optional MouthNet configuration. Recognised keys are
            ``'use'``, ``'feature_dim'``, ``'crop_param'`` and
            ``'weight_path'`` (resolved relative to this file). ``None`` or a
            falsy ``'use'`` disables MouthNet.
        in_size: Side length used as the output size of the source warp and
            to scale the translation column of the affine matrix.
        finetune: Forwarded to ``AEI_Net``.
        downup: Forwarded to ``AEI_Net``.

    Raises:
        FileNotFoundError: If ``id_ckpt`` is ``None``.
    """

    def __init__(self,
                 id_ckpt: str = None,
                 id_dim: int = 512,
                 mouth_net_param: dict = None,
                 in_size: int = 256,
                 finetune: bool = False,
                 downup: bool = False,
                 ):
        super(FSGenerator, self).__init__()

        # The declared default is ``None``; treat it as "MouthNet disabled"
        # instead of failing on ``None.get(...)``.
        if mouth_net_param is None:
            mouth_net_param = {}

        ''' MouthNet '''
        self.use_mouth_net = mouth_net_param.get('use', False)
        self.mouth_feat_dim = 0
        self.mouth_crop_param = None  # always defined, even when MouthNet is off
        self.mouth_net = None
        if self.use_mouth_net:
            self.mouth_feat_dim = mouth_net_param.get('feature_dim')
            self.mouth_crop_param = mouth_net_param.get('crop_param')
            mouth_weight_path = make_abs_path(mouth_net_param.get('weight_path'))
            self.mouth_net = MouthNet(
                bisenet=None,
                feature_dim=self.mouth_feat_dim,
                crop_param=self.mouth_crop_param
            )
            self.mouth_net.load_backbone(mouth_weight_path)
            print("[FaceShifter Generator] MouthNet loaded from %s" % mouth_weight_path)
            # MouthNet is a fixed feature extractor: eval mode, no gradients.
            self.mouth_net.eval()
            self.mouth_net.requires_grad_(False)

        ''' Generator, conditioned on the (optionally extended) identity vector '''
        self.G = AEI_Net(c_id=id_dim + self.mouth_feat_dim, finetune=finetune, downup=downup)

        ''' Face ID backbone '''
        self.iresnet = iresnet100()
        if id_ckpt is not None:
            self.iresnet.load_state_dict(torch.load(id_ckpt, map_location="cpu"))
        else:
            warnings.warn("Face ID backbone [%s] not found!" % id_ckpt)
            raise FileNotFoundError("Face ID backbone [%s] not found!" % id_ckpt)
        self.iresnet.eval()
        # NOTE(review): a later ``.train()`` on this module (or a parent) also
        # switches the frozen backbones back to train mode -- confirm that the
        # training loop accounts for this.

        # Fixed (1, 2, 3) affine matrix applied to the source image before ID
        # extraction; only the translation column is scaled by ``in_size / 256``.
        # Presumably it aligns the crop to what the ID backbone expects -- TODO confirm.
        self.register_buffer(
            name="trans_matrix",
            tensor=torch.tensor(
                [
                    [
                        [1.07695457, -0.03625215, -1.56352194 * (in_size / 256)],
                        [0.03625215, 1.07695457, -5.32134629 * (in_size / 256)],
                    ]
                ],
                requires_grad=False,
            ).float(),
        )
        self.in_size = in_size

        self.iresnet.requires_grad_(False)

    def forward(self, source, target, infer=False):
        """Generate a swapped image from ``source`` identity and ``target``.

        Args:
            source: Batch of source images the identity is taken from.
            target: Batch of target images passed to ``AEI_Net``.
            infer: When ``False`` the source is first warped with
                ``self.trans_matrix`` to ``(in_size, in_size)``; when ``True``
                the warp is skipped and the source is used as given.

        Returns:
            Tuple ``(x, id_vector, att)`` where ``x`` and ``att`` are the two
            outputs of ``self.G(target, id_vector)`` and ``id_vector`` is the
            L2-normalised ID embedding, concatenated with the MouthNet feature
            along the last dimension when MouthNet is enabled.
        """
        # Identity / mouth features come from frozen networks: no autograd graph.
        with torch.no_grad():
            ''' 1. get id '''
            if not infer:
                M = self.trans_matrix.repeat(source.size()[0], 1, 1)
                source = kornia.geometry.transform.warp_affine(source, M, (self.in_size, self.in_size))

            # Both modes share the same path from here: resize to 112x112,
            # embed, then L2-normalise along the feature dimension.
            resize_input = F.interpolate(source, size=112, mode="bilinear", align_corners=True)
            id_vector = F.normalize(self.iresnet(resize_input), dim=-1, p=2)

            ''' 2. get mouth feature '''
            if self.use_mouth_net:
                w1, h1, w2, h2 = self.mouth_crop_param
                mouth_input = resize_input[:, :, h1:h2, w1:w2]  # 112->mouth
                mouth_feat = self.mouth_net(mouth_input)
                id_vector = torch.cat([id_vector, mouth_feat], dim=-1)  # (B,dim_id+dim_mouth)

        x, att = self.G(target, id_vector)
        return x, id_vector, att

    def get_recon(self):
        """Return ``self.G.get_recon_tensor()``."""
        return self.G.get_recon_tensor()

    def get_att(self, x):
        """Return ``self.G.get_attr(x)`` for input ``x``."""
        return self.G.get_attr(x)


class FSHearNet(nn.Module):
    """FaceShifter stage II: a frozen, pre-trained ``FSGenerator`` + ``Hear_Net``.

    Args:
        aei_path: Path to the pre-trained stage-I weights. A path containing
            ``'.ckpt'`` is converted with ``extract_generator``; a path
            containing ``'.pth'`` is loaded directly as a state dict.

    Raises:
        FileNotFoundError: If ``aei_path`` contains neither ``'.ckpt'`` nor
            ``'.pth'``.
    """

    def __init__(self, aei_path: str):
        super(FSHearNet, self).__init__()
        ''' Stage I. AEI_Net '''
        # MouthNet is disabled explicitly: leaving ``mouth_net_param`` at its
        # ``None`` default is not safe because the generator reads it as a dict.
        self.aei = FSGenerator(
            id_ckpt=make_abs_path("../../modules/third_party/arcface/weights/ms1mv3_arcface_r100_fp16/backbone.pth"),
            mouth_net_param={'use': False},
        ).requires_grad_(False)
        print('Loading pre-trained AEI-Net from %s...' % aei_path)
        self._load_pretrained_aei(aei_path)
        print('Loaded.')

        ''' Stage II. HEAR_Net '''
        self.hear = Hear_Net()

    def _load_pretrained_aei(self, path: str):
        """Load stage-I weights from ``path`` into ``self.aei`` and set eval mode.

        Args:
            path: Checkpoint location; the format is chosen by substring match
                on ``'.ckpt'`` first, then ``'.pth'``.

        Raises:
            FileNotFoundError: If ``path`` matches neither format.
        """
        if '.ckpt' in path:
            # Imported lazily so the trainer package is only needed for .ckpt input.
            from trainer.faceshifter.extract_ckpt import extract_generator
            pth_folder = make_abs_path('../../trainer/faceshifter/extracted_ckpt')
            pth_name = 'hear_tmp.pth'
            state_dict = extract_generator(load_path=path, path=os.path.join(pth_folder, pth_name))
        elif '.pth' in path:
            state_dict = torch.load(path, "cpu")
        else:
            raise FileNotFoundError('%s (.ckpt or .pth) not found.' % path)

        # strict=False: keys that do not match the generator are ignored silently.
        self.aei.load_state_dict(state_dict, strict=False)
        self.aei.eval()

    def forward(self, source, target):
        """Run both stages and return the refined and the stage-I results.

        Args:
            source: Batch of source (identity) images.
            target: Batch of target images.

        Returns:
            Tuple ``(y_st, y_hat_st)``: the ``Hear_Net`` output and the stage-I
            output for ``(source, target)``. The original author's note gives
            both as (B,3,256,256).
        """
        # Stage I is frozen, so everything up to the HEAR input needs no gradients.
        with torch.no_grad():
            y_hat_st, _, _ = self.aei(source, target, infer=True)
            # Swapping the target onto itself; the residual against the real
            # target is what the second stage receives as extra input.
            y_hat_tt, _, _ = self.aei(target, target, infer=True)
            delta_y_t = target - y_hat_tt
            y_cat = torch.cat([y_hat_st, delta_y_t], dim=1)  # (B,6,256,256)

        y_st = self.hear(y_cat)

        return y_st, y_hat_st  # both (B,3,256,256)


if __name__ == '__main__':
    # Manual smoke test: push random tensors through the stage-I generator
    # and print the shape of the generated batch.
    batch_size, side = 8, 512
    source = torch.randn(batch_size, 3, side, side)
    target = torch.randn(batch_size, 3, side, side)

    # Machine-specific location of the face-ID backbone weights.
    id_backbone_path = (
        "/apdcephfs_cq2/share_1290939/gavinyuan/code/FaceShifter/faceswap/faceswap/checkpoints/"
        "face_id/ms1mv3_arcface_r100_fp16_backbone.pth"
    )
    mouth_cfg = {'use': False}
    net = FSGenerator(id_ckpt=id_backbone_path, mouth_net_param=mouth_cfg)

    # Only the generated image is inspected; the ID vector and attributes are dropped.
    generated, _id_vector, _attributes = net(source, target)
    print('result:', generated.shape)

    # Stage-II check, kept disabled by the original author:
    # stage2 = FSHearNet(
    #     aei_path=make_abs_path("../../trainer/faceshifter/out/faceshifter_vanilla/epoch=32-step=509999.ckpt")
    # )
    # final_out, _ = stage2(source, target)
    # print('final out:', final_out.shape)