#!/usr/bin/env python3
# -*- coding:utf-8 -*-
#############################################################
# File: fs_model_fix_idnorm_donggp_saveoptim copy.py
# Created Date: Wednesday January 12th 2022
# Author: Chen Xuanhong
# Email: chenxuanhongzju@outlook.com
# Last Modified: Thursday, 21st April 2022 8:13:37 pm
# Modified By: Chen Xuanhong
# Copyright (c) 2022 Shanghai Jiao Tong University
#############################################################
import torch
import torch.nn as nn
from modules.layers.simswap.base_model import BaseModel
from modules.layers.simswap.fs_networks_fix import Generator_Adain_Upsample
from modules.layers.simswap.pg_modules.projected_discriminator import ProjectedDiscriminator


def compute_grad2(d_out, x_in):
    # R1-style gradient penalty: per-sample squared L2 norm of the
    # discriminator's gradients with respect to its input.
    batch_size = x_in.size(0)
    grad_dout = torch.autograd.grad(
        outputs=d_out.sum(), inputs=x_in,
        create_graph=True, retain_graph=True, only_inputs=True
    )[0]
    grad_dout2 = grad_dout.pow(2)
    assert grad_dout2.size() == x_in.size()
    reg = grad_dout2.view(batch_size, -1).sum(1)
    return reg
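
# Typical usage (a sketch, not taken from this file; `real`, `netD`, and
# `lambda_r1` are illustrative names): weight the penalty and add it to the
# discriminator loss on real samples, e.g.
#   real.requires_grad_(True)
#   real_logits, _ = netD(real, None)
#   d_loss = d_loss + lambda_r1 * compute_grad2(real_logits, real).mean()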


class fsModel(BaseModel):
    def name(self):
        return 'fsModel'

    def initialize(self, opt):
        BaseModel.initialize(self, opt)
        # if opt.resize_or_crop != 'none' or not opt.isTrain:  # when training at full res this causes OOM
        self.isTrain = opt.isTrain

        # Generator network
        self.netG = Generator_Adain_Upsample(input_nc=3, output_nc=3, latent_size=512, n_blocks=9, deep=opt.Gdeep)
        self.netG.cuda()

        # Id network
        from third_party.arcface import iresnet100
        netArc_pth = "/apdcephfs_cq2/share_1290939/gavinyuan/code/FaceShifter/faceswap/faceswap/" \
                     "checkpoints/face_id/ms1mv3_arcface_r100_fp16_backbone.pth"  # opt.Arc_path
        self.netArc = iresnet100(pretrained=False, fp16=False)
        self.netArc.load_state_dict(torch.load(netArc_pth, map_location="cpu"))
        # netArc_checkpoint = opt.Arc_path
        # netArc_checkpoint = torch.load(netArc_checkpoint, map_location=torch.device("cpu"))
        # self.netArc = netArc_checkpoint['model'].module
        self.netArc = self.netArc.cuda()
        self.netArc.eval()
        self.netArc.requires_grad_(False)
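        # The frozen ArcFace backbone is used only to extract 512-d identity
        # embeddings; it expects 112x112 inputs (the smoke test at the bottom
        # of this file resizes with F.interpolate before calling it).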

        if not self.isTrain:
            pretrained_path = opt.checkpoints_dir
            self.load_network(self.netG, 'G', opt.which_epoch, pretrained_path)
            return

        self.netD = ProjectedDiscriminator(diffaug=False, interp224=False, **{})
        # self.netD.feature_network.requires_grad_(False)
        self.netD.cuda()

        if self.isTrain:
            # define loss functions
            self.criterionFeat = nn.L1Loss()
            self.criterionRec = nn.L1Loss()

            # initialize optimizers
            # optimizer G
            params = list(self.netG.parameters())
            self.optimizer_G = torch.optim.Adam(params, lr=opt.lr, betas=(opt.beta1, 0.99), eps=1e-8)

            # optimizer D
            params = list(self.netD.parameters())
            self.optimizer_D = torch.optim.Adam(params, lr=opt.lr, betas=(opt.beta1, 0.99), eps=1e-8)
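
            # NOTE: beta1 defaults to 0.0 (see TrainOptions below), so both
            # optimizers run Adam with betas (0.0, 0.99), a setting commonly
            # used for GAN training rather than the classic (0.9, 0.999).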

            # load networks
            if opt.continue_train:
                pretrained_path = '' if not self.isTrain else opt.load_pretrain
                # print(pretrained_path)
                self.load_network(self.netG, 'G', opt.which_epoch, pretrained_path)
                self.load_network(self.netD, 'D', opt.which_epoch, pretrained_path)
                self.load_optim(self.optimizer_G, 'G', opt.which_epoch, pretrained_path)
                self.load_optim(self.optimizer_D, 'D', opt.which_epoch, pretrained_path)
        torch.cuda.empty_cache()

    def cosin_metric(self, x1, x2):
        # return np.dot(x1, x2) / (np.linalg.norm(x1) * np.linalg.norm(x2))
        return torch.sum(x1 * x2, dim=1) / (torch.norm(x1, dim=1) * torch.norm(x2, dim=1))
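
    # cosin_metric above is equivalent to
    # torch.nn.functional.cosine_similarity(x1, x2, dim=1), up to the eps
    # clamping that cosine_similarity applies to the norms.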

    def save(self, which_epoch):
        self.save_network(self.netG, 'G', which_epoch)
        self.save_network(self.netD, 'D', which_epoch)
        self.save_optim(self.optimizer_G, 'G', which_epoch)
        self.save_optim(self.optimizer_D, 'D', which_epoch)
        '''if self.gen_features:
            self.save_network(self.netE, 'E', which_epoch, self.gpu_ids)'''

    def update_fixed_params(self):
        raise ValueError('Not used')
        # NOTE: everything below is unreachable; kept from an earlier
        # implementation for reference.
        # after fixing the global generator for a number of iterations, also start finetuning it
        params = list(self.netG.parameters())
        if self.gen_features:
            params += list(self.netE.parameters())
        self.optimizer_G = torch.optim.Adam(params, lr=self.opt.lr, betas=(self.opt.beta1, 0.999))
        if self.opt.verbose:
            print('------------ Now also finetuning global generator -----------')

    def update_learning_rate(self):
        raise ValueError('Not used')
        # NOTE: unreachable, kept for reference.
        lrd = self.opt.lr / self.opt.niter_decay
        lr = self.old_lr - lrd
        for param_group in self.optimizer_D.param_groups:
            param_group['lr'] = lr
        for param_group in self.optimizer_G.param_groups:
            param_group['lr'] = lr
        if self.opt.verbose:
            print('update learning rate: %f -> %f' % (self.old_lr, lr))
        self.old_lr = lr


if __name__ == "__main__":
    import os
    import argparse

    def str2bool(v):
        # The original `v.lower() in ('true')` tested substring membership in
        # the string 'true' (so 'rue' was truthy); compare explicitly instead.
        return v.lower() == 'true'

    class TrainOptions:
        def __init__(self):
            self.parser = argparse.ArgumentParser()
            self.initialized = False

        def initialize(self):
            self.parser.add_argument('--name', type=str, default='simswap',
                                     help='name of the experiment. It decides where to store samples and models')
            self.parser.add_argument('--gpu_ids', default='0')
            self.parser.add_argument('--checkpoints_dir', type=str, default='./checkpoints',
                                     help='models are saved here')
            self.parser.add_argument('--isTrain', type=str2bool, default='True')

            # input/output sizes
            self.parser.add_argument('--batchSize', type=int, default=8, help='input batch size')

            # for displays
            self.parser.add_argument('--use_tensorboard', type=str2bool, default='False')

            # for training
            self.parser.add_argument('--dataset', type=str, default="/path/to/VGGFace2",
                                     help='path to the face swapping dataset')
            self.parser.add_argument('--continue_train', type=str2bool, default='False',
                                     help='continue training: load the latest model')
            self.parser.add_argument('--load_pretrain', type=str, default='./checkpoints/simswap224_test',
                                     help='load the pretrained model from the specified location')
            self.parser.add_argument('--which_epoch', type=str, default='10000',
                                     help='which epoch to load? set to latest to use latest cached model')
            self.parser.add_argument('--phase', type=str, default='train', help='train, val, test, etc')
            self.parser.add_argument('--niter', type=int, default=10000, help='# of iters at starting learning rate')
            self.parser.add_argument('--niter_decay', type=int, default=10000,
                                     help='# of iters to linearly decay learning rate to zero')
            self.parser.add_argument('--beta1', type=float, default=0.0, help='momentum term of adam')
            self.parser.add_argument('--lr', type=float, default=0.0004, help='initial learning rate for adam')
            self.parser.add_argument('--Gdeep', type=str2bool, default='False')

            # for discriminators
            self.parser.add_argument('--lambda_feat', type=float, default=10.0, help='weight for feature matching loss')
            self.parser.add_argument('--lambda_id', type=float, default=30.0, help='weight for id loss')
            self.parser.add_argument('--lambda_rec', type=float, default=10.0, help='weight for reconstruction loss')
            self.parser.add_argument("--Arc_path", type=str, default='arcface_model/arcface_checkpoint.tar',
                                     help="path to the pretrained ArcFace checkpoint")
            self.parser.add_argument("--total_step", type=int, default=1000000, help='total training steps')
            self.parser.add_argument("--log_frep", type=int, default=200, help='frequency of printing log information')
            self.parser.add_argument("--sample_freq", type=int, default=1000, help='frequency of sampling')
            self.parser.add_argument("--model_freq", type=int, default=10000, help='frequency of saving the model')

            self.isTrain = True

        def parse(self, save=True):
            if not self.initialized:
                self.initialize()
            self.opt = self.parser.parse_args()
            self.opt.isTrain = self.isTrain  # train or test

            args = vars(self.opt)
            print('------------ Options -------------')
            for k, v in sorted(args.items()):
                print('%s: %s' % (str(k), str(v)))
            print('-------------- End ----------------')

            # save to the disk
            # if self.opt.isTrain:
            #     expr_dir = os.path.join(self.opt.checkpoints_dir, self.opt.name)
            #     util.mkdirs(expr_dir)
            #     if save and not self.opt.continue_train:
            #         file_name = os.path.join(expr_dir, 'opt.txt')
            #         with open(file_name, 'wt') as opt_file:
            #             opt_file.write('------------ Options -------------\n')
            #             for k, v in sorted(args.items()):
            #                 opt_file.write('%s: %s\n' % (str(k), str(v)))
            #             opt_file.write('-------------- End ----------------\n')
            return self.opt
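
    # Smoke test: build the model and run one forward pass of the ArcFace
    # encoder, the generator, and the discriminator on random tensors
    # (requires a CUDA device and the checkpoints referenced above).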
    source = torch.randn(8, 3, 256, 256).cuda()
    target = torch.randn(8, 3, 256, 256).cuda()
    opt = TrainOptions().parse()
    model = fsModel()
    model.initialize(opt)

    import torch.nn.functional as F

    img_id_112 = F.interpolate(source, size=(112, 112), mode='bicubic')
    latent_id = model.netArc(img_id_112)
    latent_id = F.normalize(latent_id, p=2, dim=1)

    img_fake = model.netG(target, latent_id)
    gen_logits, _ = model.netD(img_fake.detach(), None)
    loss_Dgen = (F.relu(torch.ones_like(gen_logits) + gen_logits)).mean()
    real_logits, _ = model.netD(source, None)
    print('img_fake:', img_fake.shape, 'real_logits:', real_logits.shape)
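
    # A minimal sketch (not part of the original file) of the remaining
    # hinge-loss terms, assuming the same setup as above; the actual training
    # loop lives elsewhere in the repo:
    loss_Dreal = F.relu(torch.ones_like(real_logits) - real_logits).mean()
    loss_D = loss_Dgen + loss_Dreal
    gen_logits_fake, _ = model.netD(img_fake, None)
    loss_Gadv = (-gen_logits_fake).mean()
    img_fake_112 = F.interpolate(img_fake, size=(112, 112), mode='bicubic')
    latent_fake = F.normalize(model.netArc(img_fake_112), p=2, dim=1)
    loss_Gid = (1 - model.cosin_metric(latent_fake, latent_id)).mean()
    print('loss_D:', loss_D.item(), 'loss_Gadv:', loss_Gadv.item(), 'loss_Gid:', loss_Gid.item())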