#!/usr/bin/env python3
# -*- coding:utf-8 -*-
#############################################################
# File: fs_model_fix_idnorm_donggp_saveoptim copy.py
# Created Date: Wednesday January 12th 2022
# Author: Chen Xuanhong
# Email: chenxuanhongzju@outlook.com
# Last Modified:  Thursday, 21st April 2022 8:13:37 pm
# Modified By: Chen Xuanhong
# Copyright (c) 2022 Shanghai Jiao Tong University
#############################################################


import torch
import torch.nn as nn

from modules.layers.simswap.base_model import BaseModel
from modules.layers.simswap.fs_networks_fix import Generator_Adain_Upsample

from modules.layers.simswap.pg_modules.projected_discriminator import ProjectedDiscriminator


def compute_grad2(d_out, x_in):
    """Per-sample squared gradient norm of the discriminator output w.r.t.
    its input, i.e. the R1 gradient-penalty term."""
    batch_size = x_in.size(0)
    grad_dout = torch.autograd.grad(
        outputs=d_out.sum(), inputs=x_in,
        create_graph=True, retain_graph=True, only_inputs=True
    )[0]
    grad_dout2 = grad_dout.pow(2)
    assert grad_dout2.size() == x_in.size()
    reg = grad_dout2.view(batch_size, -1).sum(1)
    return reg
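
# A minimal usage sketch for compute_grad2 as an R1 gradient penalty
# (hedged: this file never calls it; names such as netD, real_images and
# lambda_r1 are assumptions standing in for the actual training loop):
#
#   real_images.requires_grad_(True)
#   real_logits, _ = netD(real_images, None)
#   loss_r1 = lambda_r1 * compute_grad2(real_logits, real_images).mean()
#   (loss_Dreal + loss_r1).backward()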


class fsModel(BaseModel):
    def name(self):
        return 'fsModel'

    def initialize(self, opt):
        BaseModel.initialize(self, opt)
        # if opt.resize_or_crop != 'none' or not opt.isTrain:  # when training at full res this causes OOM
        self.isTrain = opt.isTrain

        # Generator network
        self.netG = Generator_Adain_Upsample(input_nc=3, output_nc=3, latent_size=512, n_blocks=9, deep=opt.Gdeep)
        self.netG.cuda()

        # Id network
        from third_party.arcface import iresnet100
        # NOTE: hard-coded checkpoint path; this overrides opt.Arc_path.
        netArc_pth = "/apdcephfs_cq2/share_1290939/gavinyuan/code/FaceShifter/faceswap/faceswap/" \
                     "checkpoints/face_id/ms1mv3_arcface_r100_fp16_backbone.pth"  # opt.Arc_path
        self.netArc = iresnet100(pretrained=False, fp16=False)
        self.netArc.load_state_dict(torch.load(netArc_pth, map_location="cpu"))
        # netArc_checkpoint = opt.Arc_path
        # netArc_checkpoint = torch.load(netArc_checkpoint, map_location=torch.device("cpu"))
        # self.netArc = netArc_checkpoint['model'].module
        self.netArc = self.netArc.cuda()
        self.netArc.eval()
        self.netArc.requires_grad_(False)
        if not self.isTrain:
            pretrained_path = opt.checkpoints_dir
            self.load_network(self.netG, 'G', opt.which_epoch, pretrained_path)
            return
        self.netD = ProjectedDiscriminator(diffaug=False, interp224=False)
        # self.netD.feature_network.requires_grad_(False)
        self.netD.cuda()


        if self.isTrain:
            # define loss functions
            self.criterionFeat  = nn.L1Loss()
            self.criterionRec   = nn.L1Loss()

            # initialize optimizers
            # optimizer G
            params = list(self.netG.parameters())
            self.optimizer_G = torch.optim.Adam(params, lr=opt.lr, betas=(opt.beta1, 0.99), eps=1e-8)

            # optimizer D
            params = list(self.netD.parameters())
            self.optimizer_D = torch.optim.Adam(params, lr=opt.lr, betas=(opt.beta1, 0.99), eps=1e-8)

        # load networks
        if opt.continue_train:
            pretrained_path = '' if not self.isTrain else opt.load_pretrain
            # print (pretrained_path)
            self.load_network(self.netG, 'G', opt.which_epoch, pretrained_path)
            self.load_network(self.netD, 'D', opt.which_epoch, pretrained_path)
            self.load_optim(self.optimizer_G, 'G', opt.which_epoch, pretrained_path)
            self.load_optim(self.optimizer_D, 'D', opt.which_epoch, pretrained_path)
        torch.cuda.empty_cache()

    def cosin_metric(self, x1, x2):
        # Row-wise cosine similarity, equivalent to F.cosine_similarity(x1, x2, dim=1);
        # inputs need not be pre-normalized.
        return torch.sum(x1 * x2, dim=1) / (torch.norm(x1, dim=1) * torch.norm(x2, dim=1))
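
    # A hedged sketch of how this metric typically feeds an identity loss
    # (assumption: the actual loss assembly lives in the training script,
    # not in this module; latent_fake and latent_id are L2-normalized
    # ArcFace embeddings of the swapped and source faces):
    #   loss_G_ID = (1 - self.cosin_metric(latent_fake, latent_id)).mean()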

    def save(self, which_epoch):
        self.save_network(self.netG, 'G', which_epoch)
        self.save_network(self.netD, 'D', which_epoch)
        self.save_optim(self.optimizer_G, 'G', which_epoch)
        self.save_optim(self.optimizer_D, 'D', which_epoch)
        # if self.gen_features:
        #     self.save_network(self.netE, 'E', which_epoch, self.gpu_ids)

    def update_fixed_params(self):
        raise NotImplementedError('Not used')
        # after fixing the global generator for a number of iterations, also start finetuning it
        params = list(self.netG.parameters())
        if self.gen_features:
            params += list(self.netE.parameters())
        self.optimizer_G = torch.optim.Adam(params, lr=self.opt.lr, betas=(self.opt.beta1, 0.999))
        if self.opt.verbose:
            print('------------ Now also finetuning global generator -----------')

    def update_learning_rate(self):
        raise NotImplementedError('Not used')
        lrd = self.opt.lr / self.opt.niter_decay
        lr = self.old_lr - lrd
        for param_group in self.optimizer_D.param_groups:
            param_group['lr'] = lr
        for param_group in self.optimizer_G.param_groups:
            param_group['lr'] = lr
        if self.opt.verbose:
            print('update learning rate: %f -> %f' % (self.old_lr, lr))
        self.old_lr = lr


if __name__ == "__main__":
    import os
    import argparse

    def str2bool(v):
        # ('true') is just the string 'true', so the original membership test
        # was a substring check; a one-element tuple makes it an exact match.
        return v.lower() in ('true',)


    class TrainOptions:
        def __init__(self):
            self.parser = argparse.ArgumentParser()
            self.initialized = False

        def initialize(self):
            self.parser.add_argument('--name', type=str, default='simswap',
                                     help='name of the experiment. It decides where to store samples and models')
            self.parser.add_argument('--gpu_ids', default='0')
            self.parser.add_argument('--checkpoints_dir', type=str, default='./checkpoints',
                                     help='models are saved here')
            self.parser.add_argument('--isTrain', type=str2bool, default='True')

            # input/output sizes
            self.parser.add_argument('--batchSize', type=int, default=8, help='input batch size')

            # for displays
            self.parser.add_argument('--use_tensorboard', type=str2bool, default='False')

            # for training
            self.parser.add_argument('--dataset', type=str, default="/path/to/VGGFace2",
                                     help='path to the face swapping dataset')
            self.parser.add_argument('--continue_train', type=str2bool, default='False',
                                     help='continue training: load the latest model')
            self.parser.add_argument('--load_pretrain', type=str, default='./checkpoints/simswap224_test',
                                     help='load the pretrained model from the specified location')
            self.parser.add_argument('--which_epoch', type=str, default='10000',
                                     help='which epoch to load? set to latest to use latest cached model')
            self.parser.add_argument('--phase', type=str, default='train', help='train, val, test, etc')
            self.parser.add_argument('--niter', type=int, default=10000, help='# of iter at starting learning rate')
            self.parser.add_argument('--niter_decay', type=int, default=10000,
                                     help='# of iter to linearly decay learning rate to zero')
            self.parser.add_argument('--beta1', type=float, default=0.0, help='momentum term of adam')
            self.parser.add_argument('--lr', type=float, default=0.0004, help='initial learning rate for adam')
            self.parser.add_argument('--Gdeep', type=str2bool, default='False')

            # for discriminators
            self.parser.add_argument('--lambda_feat', type=float, default=10.0, help='weight for feature matching loss')
            self.parser.add_argument('--lambda_id', type=float, default=30.0, help='weight for id loss')
            self.parser.add_argument('--lambda_rec', type=float, default=10.0, help='weight for reconstruction loss')

            self.parser.add_argument("--Arc_path", type=str, default='arcface_model/arcface_checkpoint.tar',
                                     help="run ONNX model via TRT")
            self.parser.add_argument("--total_step", type=int, default=1000000, help='total training step')
            self.parser.add_argument("--log_frep", type=int, default=200, help='frequence for printing log information')
            self.parser.add_argument("--sample_freq", type=int, default=1000, help='frequence for sampling')
            self.parser.add_argument("--model_freq", type=int, default=10000, help='frequence for saving the model')

            self.isTrain = True

        def parse(self, save=True):
            if not self.initialized:
                self.initialize()
            self.opt = self.parser.parse_args()
            self.opt.isTrain = self.isTrain  # train or test

            args = vars(self.opt)

            print('------------ Options -------------')
            for k, v in sorted(args.items()):
                print('%s: %s' % (str(k), str(v)))
            print('-------------- End ----------------')

            # save to the disk
            # if self.opt.isTrain:
            #     expr_dir = os.path.join(self.opt.checkpoints_dir, self.opt.name)
            #     util.mkdirs(expr_dir)
            #     if save and not self.opt.continue_train:
            #         file_name = os.path.join(expr_dir, 'opt.txt')
            #         with open(file_name, 'wt') as opt_file:
            #             opt_file.write('------------ Options -------------\n')
            #             for k, v in sorted(args.items()):
            #                 opt_file.write('%s: %s\n' % (str(k), str(v)))
            #             opt_file.write('-------------- End ----------------\n')
            return self.opt

    source = torch.randn(8, 3, 256, 256).cuda()
    target = torch.randn(8, 3, 256, 256).cuda()

    opt = TrainOptions().parse()
    model = fsModel()
    model.initialize(opt)

    import torch.nn.functional as F
    img_id_112 = F.interpolate(source, size=(112, 112), mode='bicubic')
    latent_id = model.netArc(img_id_112)
    latent_id = F.normalize(latent_id, p=2, dim=1)

    img_fake = model.netG(target, latent_id)
    gen_logits, _ = model.netD(img_fake.detach(), None)
    loss_Dgen = (F.relu(torch.ones_like(gen_logits) + gen_logits)).mean()

    real_logits, _ = model.netD(source, None)
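
    # For reference, the matching real-side hinge term would be (an assumption;
    # this smoke test only exercises the fake branch, and the full loss
    # assembly lives in the training script):
    #   loss_Dreal = (F.relu(torch.ones_like(real_logits) - real_logits)).mean()
    #   loss_D = loss_Dgen + loss_Dreal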

    print('img_fake:', img_fake.shape, 'real_logits:', real_logits.shape)