#!/usr/bin/env python3
# -*- coding:utf-8 -*-
#############################################################
# File: fs_model_fix_idnorm_donggp_saveoptim copy.py
# Created Date: Wednesday January 12th 2022
# Author: Chen Xuanhong
# Email: chenxuanhongzju@outlook.com
# Last Modified: Thursday, 21st April 2022 8:13:37 pm
# Modified By: Chen Xuanhong
# Copyright (c) 2022 Shanghai Jiao Tong University
#############################################################
import torch
import torch.nn as nn

from modules.layers.simswap.base_model import BaseModel
from modules.layers.simswap.fs_networks_fix import Generator_Adain_Upsample
from modules.layers.simswap.pg_modules.projected_discriminator import ProjectedDiscriminator
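

# R1 gradient penalty: the squared L2 norm of the discriminator's gradient
# with respect to its input, summed per sample. Penalizing this (typically on
# real samples) discourages sharp discriminator responses and stabilizes
# GAN training.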
def compute_grad2(d_out, x_in):
    batch_size = x_in.size(0)
    grad_dout = torch.autograd.grad(
        outputs=d_out.sum(), inputs=x_in,
        create_graph=True, retain_graph=True, only_inputs=True
    )[0]
    grad_dout2 = grad_dout.pow(2)
    assert grad_dout2.size() == x_in.size()
    reg = grad_dout2.view(batch_size, -1).sum(1)
    return reg


class fsModel(BaseModel):
    def name(self):
        return 'fsModel'

    def initialize(self, opt):
        BaseModel.initialize(self, opt)
        # if opt.resize_or_crop != 'none' or not opt.isTrain:  # when training at full res this causes OOM
        self.isTrain = opt.isTrain

        # Generator network
        self.netG = Generator_Adain_Upsample(input_nc=3, output_nc=3, latent_size=512, n_blocks=9, deep=opt.Gdeep)
        self.netG.cuda()

        # Id network
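        # ArcFace (iresnet100) serves as a frozen identity encoder; its 512-d
        # embedding of the source face conditions the generator (AdaIN-style).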
        from third_party.arcface import iresnet100
        netArc_pth = "/apdcephfs_cq2/share_1290939/gavinyuan/code/FaceShifter/faceswap/faceswap/" \
                     "checkpoints/face_id/ms1mv3_arcface_r100_fp16_backbone.pth"  # opt.Arc_path
        self.netArc = iresnet100(pretrained=False, fp16=False)
        self.netArc.load_state_dict(torch.load(netArc_pth, map_location="cpu"))
        # netArc_checkpoint = opt.Arc_path
        # netArc_checkpoint = torch.load(netArc_checkpoint, map_location=torch.device("cpu"))
        # self.netArc = netArc_checkpoint['model'].module
        self.netArc = self.netArc.cuda()
        self.netArc.eval()
        self.netArc.requires_grad_(False)

        if not self.isTrain:
            pretrained_path = opt.checkpoints_dir
            self.load_network(self.netG, 'G', opt.which_epoch, pretrained_path)
            return
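
        # Discriminator network: a projected discriminator (pretrained feature
        # network with lightweight trainable heads); only needed for training.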
        self.netD = ProjectedDiscriminator(diffaug=False, interp224=False, **{})
        # self.netD.feature_network.requires_grad_(False)
        self.netD.cuda()

        if self.isTrain:
            # define loss functions
            self.criterionFeat = nn.L1Loss()
            self.criterionRec = nn.L1Loss()

            # initialize optimizers
            # optimizer G
            params = list(self.netG.parameters())
            self.optimizer_G = torch.optim.Adam(params, lr=opt.lr, betas=(opt.beta1, 0.99), eps=1e-8)

            # optimizer D
            params = list(self.netD.parameters())
            self.optimizer_D = torch.optim.Adam(params, lr=opt.lr, betas=(opt.beta1, 0.99), eps=1e-8)

        # load networks
        if opt.continue_train:
            pretrained_path = '' if not self.isTrain else opt.load_pretrain
            # print(pretrained_path)
            self.load_network(self.netG, 'G', opt.which_epoch, pretrained_path)
            self.load_network(self.netD, 'D', opt.which_epoch, pretrained_path)
            self.load_optim(self.optimizer_G, 'G', opt.which_epoch, pretrained_path)
            self.load_optim(self.optimizer_D, 'D', opt.which_epoch, pretrained_path)
        torch.cuda.empty_cache()
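
    # Per-sample cosine similarity between two batches of embeddings; the
    # identity loss is typically 1 - cosin_metric(fake_id, source_id).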
    def cosin_metric(self, x1, x2):
        # return np.dot(x1, x2) / (np.linalg.norm(x1) * np.linalg.norm(x2))
        return torch.sum(x1 * x2, dim=1) / (torch.norm(x1, dim=1) * torch.norm(x2, dim=1))

    def save(self, which_epoch):
        self.save_network(self.netG, 'G', which_epoch)
        self.save_network(self.netD, 'D', which_epoch)
        self.save_optim(self.optimizer_G, 'G', which_epoch)
        self.save_optim(self.optimizer_D, 'D', which_epoch)
        '''if self.gen_features:
            self.save_network(self.netE, 'E', which_epoch, self.gpu_ids)'''

    def update_fixed_params(self):
        raise ValueError('Not used')
        # after fixing the global generator for a number of iterations, also start finetuning it
        params = list(self.netG.parameters())
        if self.gen_features:
            params += list(self.netE.parameters())
        self.optimizer_G = torch.optim.Adam(params, lr=self.opt.lr, betas=(self.opt.beta1, 0.999))
        if self.opt.verbose:
            print('------------ Now also finetuning global generator -----------')

    def update_learning_rate(self):
        raise ValueError('Not used')
        lrd = self.opt.lr / self.opt.niter_decay
        lr = self.old_lr - lrd
        for param_group in self.optimizer_D.param_groups:
            param_group['lr'] = lr
        for param_group in self.optimizer_G.param_groups:
            param_group['lr'] = lr
        if self.opt.verbose:
            print('update learning rate: %f -> %f' % (self.old_lr, lr))
        self.old_lr = lr


if __name__ == "__main__":
    import os
    import argparse

    def str2bool(v):
        return v.lower() in ('true',)

    class TrainOptions:
        def __init__(self):
            self.parser = argparse.ArgumentParser()
            self.initialized = False

        def initialize(self):
            self.parser.add_argument('--name', type=str, default='simswap',
                                     help='name of the experiment; it decides where samples and models are stored')
            self.parser.add_argument('--gpu_ids', default='0')
            self.parser.add_argument('--checkpoints_dir', type=str, default='./checkpoints',
                                     help='models are saved here')
            self.parser.add_argument('--isTrain', type=str2bool, default='True')

            # input/output sizes
            self.parser.add_argument('--batchSize', type=int, default=8, help='input batch size')

            # for displays
            self.parser.add_argument('--use_tensorboard', type=str2bool, default='False')

            # for training
            self.parser.add_argument('--dataset', type=str, default="/path/to/VGGFace2",
                                     help='path to the face swapping dataset')
            self.parser.add_argument('--continue_train', type=str2bool, default='False',
                                     help='continue training: load the latest model')
            self.parser.add_argument('--load_pretrain', type=str, default='./checkpoints/simswap224_test',
                                     help='load the pretrained model from the specified location')
            self.parser.add_argument('--which_epoch', type=str, default='10000',
                                     help='which epoch to load? set to latest to use the latest cached model')
            self.parser.add_argument('--phase', type=str, default='train', help='train, val, test, etc.')
            self.parser.add_argument('--niter', type=int, default=10000, help='# of iterations at the starting learning rate')
            self.parser.add_argument('--niter_decay', type=int, default=10000,
                                     help='# of iterations over which to linearly decay the learning rate to zero')
            self.parser.add_argument('--beta1', type=float, default=0.0, help='momentum term of adam')
            self.parser.add_argument('--lr', type=float, default=0.0004, help='initial learning rate for adam')
            self.parser.add_argument('--Gdeep', type=str2bool, default='False')

            # for discriminators
            self.parser.add_argument('--lambda_feat', type=float, default=10.0, help='weight for feature matching loss')
            self.parser.add_argument('--lambda_id', type=float, default=30.0, help='weight for id loss')
            self.parser.add_argument('--lambda_rec', type=float, default=10.0, help='weight for reconstruction loss')
            self.parser.add_argument("--Arc_path", type=str, default='arcface_model/arcface_checkpoint.tar',
                                     help="path to the ArcFace identity checkpoint")
            self.parser.add_argument("--total_step", type=int, default=1000000, help='total training steps')
            self.parser.add_argument("--log_frep", type=int, default=200, help='frequency of printing log information')
            self.parser.add_argument("--sample_freq", type=int, default=1000, help='frequency of sampling')
            self.parser.add_argument("--model_freq", type=int, default=10000, help='frequency of saving the model')

            self.initialized = True
            self.isTrain = True

        def parse(self, save=True):
            if not self.initialized:
                self.initialize()
            self.opt = self.parser.parse_args()
            self.opt.isTrain = self.isTrain  # train or test

            args = vars(self.opt)
            print('------------ Options -------------')
            for k, v in sorted(args.items()):
                print('%s: %s' % (str(k), str(v)))
            print('-------------- End ----------------')

            # save to the disk
            # if self.opt.isTrain:
            #     expr_dir = os.path.join(self.opt.checkpoints_dir, self.opt.name)
            #     util.mkdirs(expr_dir)
            #     if save and not self.opt.continue_train:
            #         file_name = os.path.join(expr_dir, 'opt.txt')
            #         with open(file_name, 'wt') as opt_file:
            #             opt_file.write('------------ Options -------------\n')
            #             for k, v in sorted(args.items()):
            #                 opt_file.write('%s: %s\n' % (str(k), str(v)))
            #             opt_file.write('-------------- End ----------------\n')
            return self.opt
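
    # --- Minimal smoke test: one forward pass of the generator and the
    # discriminator on random tensors. Requires a CUDA device and the
    # hardcoded ArcFace checkpoint path above. ---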
    source = torch.randn(8, 3, 256, 256).cuda()
    target = torch.randn(8, 3, 256, 256).cuda()
    opt = TrainOptions().parse()
    model = fsModel()
    model.initialize(opt)

    import torch.nn.functional as F

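    # ArcFace expects 112x112 inputs; L2-normalize the identity embedding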
    img_id_112 = F.interpolate(source, size=(112, 112), mode='bicubic')
    latent_id = model.netArc(img_id_112)
    latent_id = F.normalize(latent_id, p=2, dim=1)

    img_fake = model.netG(target, latent_id)
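
    # Hinge loss for D on fakes: relu(1 + D(fake)) pushes fake logits below -1;
    # img_fake is detached so only netD receives gradients here.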
    gen_logits, _ = model.netD(img_fake.detach(), None)
    loss_Dgen = (F.relu(torch.ones_like(gen_logits) + gen_logits)).mean()
    real_logits, _ = model.netD(source, None)
    print('img_fake:', img_fake.shape, 'real_logits:', real_logits.shape)
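
    # --- Illustrative sketch (not part of the original training code): the
    # generator-side terms that --lambda_id and the hinge objective suggest.
    # The real training loop lives elsewhere; weights and terms may differ. ---
    gen_logits_g, _ = model.netD(img_fake, None)  # no detach: gradients reach netG
    loss_Gmain = (-gen_logits_g).mean()  # non-saturating hinge term for G
    img_fake_112 = F.interpolate(img_fake, size=(112, 112), mode='bicubic')
    latent_fake = F.normalize(model.netArc(img_fake_112), p=2, dim=1)
    loss_G_ID = (1 - model.cosin_metric(latent_fake, latent_id)).mean()  # identity loss
    loss_G = loss_Gmain + opt.lambda_id * loss_G_ID
    print('loss_Dgen:', loss_Dgen.item(), 'loss_G:', loss_G.item())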